mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 09:26:02 +02:00
mass test approaches
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription, AudioStreamTrack, VideoStreamTrack
|
||||
from av import AudioFrame, VideoFrame
|
||||
|
||||
from facefusion import logger
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
BRIDGE_PORT_START : int = 8893
|
||||
AUDIO_SAMPLE_RATE : int = 48000
|
||||
|
||||
|
||||
class FramePushTrack(VideoStreamTrack):
|
||||
kind = 'video'
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._frame : Optional[VisionFrame] = None
|
||||
self._lock = threading.Lock()
|
||||
self._started = False
|
||||
|
||||
def push(self, vision_frame : VisionFrame) -> None:
|
||||
with self._lock:
|
||||
self._frame = vision_frame
|
||||
|
||||
async def recv(self) -> VideoFrame:
|
||||
pts, time_base = await self.next_timestamp()
|
||||
|
||||
with self._lock:
|
||||
frame_data = self._frame
|
||||
|
||||
if frame_data is None:
|
||||
frame_data = numpy.zeros((240, 320, 3), dtype = numpy.uint8)
|
||||
|
||||
if not self._started:
|
||||
self._started = True
|
||||
logger.info('aiortc track sending first frame', __name__)
|
||||
|
||||
video_frame = VideoFrame.from_ndarray(frame_data, format = 'bgr24')
|
||||
video_frame.pts = pts
|
||||
video_frame.time_base = time_base
|
||||
return video_frame
|
||||
|
||||
|
||||
class AudioPushTrack(AudioStreamTrack):
|
||||
kind = 'audio'
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._buffer = bytearray()
|
||||
self._lock = threading.Lock()
|
||||
self._pts = 0
|
||||
self._frame_samples = 960
|
||||
|
||||
def push(self, pcm_data : bytes) -> None:
|
||||
with self._lock:
|
||||
self._buffer.extend(pcm_data)
|
||||
|
||||
if len(self._buffer) > AUDIO_SAMPLE_RATE * 4:
|
||||
self._buffer = self._buffer[-AUDIO_SAMPLE_RATE * 4:]
|
||||
|
||||
async def recv(self) -> AudioFrame:
|
||||
await asyncio.sleep(self._frame_samples / AUDIO_SAMPLE_RATE)
|
||||
needed = self._frame_samples * 2 * 2
|
||||
|
||||
with self._lock:
|
||||
if len(self._buffer) >= needed:
|
||||
chunk = bytes(self._buffer[:needed])
|
||||
del self._buffer[:needed]
|
||||
else:
|
||||
chunk = None
|
||||
|
||||
if chunk:
|
||||
pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1)
|
||||
else:
|
||||
pcm = numpy.zeros((1, self._frame_samples * 2), dtype = numpy.int16)
|
||||
|
||||
audio_frame = AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo')
|
||||
audio_frame.sample_rate = AUDIO_SAMPLE_RATE
|
||||
audio_frame.pts = self._pts
|
||||
self._pts += self._frame_samples
|
||||
return audio_frame
|
||||
|
||||
|
||||
class AiortcBridge:
|
||||
def __init__(self) -> None:
|
||||
global BRIDGE_PORT_START
|
||||
self.port = BRIDGE_PORT_START
|
||||
BRIDGE_PORT_START += 1
|
||||
self.video_track = FramePushTrack()
|
||||
self.audio_track = AudioPushTrack()
|
||||
self.pcs : list = []
|
||||
self._http_thread : Optional[threading.Thread] = None
|
||||
self._running = False
|
||||
self._has_viewer = False
|
||||
self._loop = None
|
||||
|
||||
async def start(self) -> None:
|
||||
self._running = True
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
|
||||
self._http_thread.start()
|
||||
logger.info('aiortc bridge started on port ' + str(self.port), __name__)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
for pc in self.pcs:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(pc.close(), loop)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def push_frame(self, vision_frame : VisionFrame) -> None:
|
||||
self.video_track.push(vision_frame)
|
||||
|
||||
def push_audio(self, audio_data : bytes) -> None:
|
||||
self.audio_track.push(audio_data)
|
||||
|
||||
def has_viewer(self) -> bool:
|
||||
return self._has_viewer
|
||||
|
||||
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
|
||||
if not self._loop:
|
||||
return None
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(self._create_pc(sdp_offer), self._loop)
|
||||
|
||||
try:
|
||||
return future.result(timeout = 10)
|
||||
except Exception as exception:
|
||||
logger.error('whep error: ' + str(exception), __name__)
|
||||
return None
|
||||
|
||||
async def _create_pc(self, sdp_offer : str) -> Optional[str]:
|
||||
pc = RTCPeerConnection()
|
||||
self.pcs.append(pc)
|
||||
pc.addTrack(self.video_track)
|
||||
pc.addTrack(self.audio_track)
|
||||
|
||||
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
|
||||
await pc.setRemoteDescription(offer)
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
self._has_viewer = True
|
||||
return pc.localDescription.sdp
|
||||
|
||||
def _run_http(self) -> None:
|
||||
bridge = self
|
||||
|
||||
class WhepHandler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args) -> None:
|
||||
pass
|
||||
|
||||
def do_OPTIONS(self) -> None:
|
||||
self.send_response(200)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
|
||||
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self) -> None:
|
||||
content_length = int(self.headers.get('Content-Length', 0))
|
||||
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
|
||||
answer = bridge._handle_whep(body)
|
||||
|
||||
if answer:
|
||||
self.send_response(201)
|
||||
self.send_header('Content-Type', 'application/sdp')
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
self.wfile.write(answer.encode('utf-8'))
|
||||
else:
|
||||
self.send_response(500)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
|
||||
server = HTTPServer(('0.0.0.0', self.port), WhepHandler)
|
||||
server.timeout = 1
|
||||
|
||||
while self._running:
|
||||
server.handle_request()
|
||||
|
||||
|
||||
class WhipAiortcBridge:
|
||||
def __init__(self) -> None:
|
||||
global BRIDGE_PORT_START
|
||||
self.port = BRIDGE_PORT_START
|
||||
BRIDGE_PORT_START += 1
|
||||
self.whip_port = BRIDGE_PORT_START
|
||||
BRIDGE_PORT_START += 1
|
||||
self._ingest_pc = None
|
||||
self._relay_track = None
|
||||
self._viewer_pcs : list = []
|
||||
self._http_thread : Optional[threading.Thread] = None
|
||||
self._running = False
|
||||
self._loop = None
|
||||
self._ingest_ready = False
|
||||
|
||||
async def start(self) -> None:
|
||||
self._running = True
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
|
||||
self._http_thread.start()
|
||||
logger.info('whip-aiortc bridge whip=' + str(self.whip_port) + ' whep=' + str(self.port), __name__)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
if self._ingest_pc:
|
||||
await self._ingest_pc.close()
|
||||
|
||||
for pc in self._viewer_pcs:
|
||||
await pc.close()
|
||||
|
||||
def get_whip_url(self) -> str:
|
||||
return 'http://localhost:' + str(self.whip_port) + '/whip'
|
||||
|
||||
def get_whep_url(self) -> str:
|
||||
return 'http://localhost:' + str(self.port) + '/whep'
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self._ingest_ready
|
||||
|
||||
def _handle_whip(self, sdp_offer : str) -> Optional[str]:
|
||||
if not self._loop:
|
||||
return None
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(self._create_ingest(sdp_offer), self._loop)
|
||||
|
||||
try:
|
||||
return future.result(timeout = 10)
|
||||
except Exception as exception:
|
||||
logger.error('whip ingest error: ' + str(exception), __name__)
|
||||
return None
|
||||
|
||||
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
|
||||
if not self._loop:
|
||||
return None
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(self._create_viewer(sdp_offer), self._loop)
|
||||
|
||||
try:
|
||||
return future.result(timeout = 10)
|
||||
except Exception as exception:
|
||||
logger.error('whep error: ' + str(exception), __name__)
|
||||
return None
|
||||
|
||||
async def _create_ingest(self, sdp_offer : str) -> Optional[str]:
|
||||
from aiortc import MediaStreamTrack
|
||||
from aiortc.contrib.media import MediaRelay
|
||||
|
||||
pc = RTCPeerConnection()
|
||||
self._ingest_pc = pc
|
||||
self._relay = MediaRelay()
|
||||
|
||||
@pc.on('track')
|
||||
def on_track(track : MediaStreamTrack) -> None:
|
||||
if track.kind == 'video':
|
||||
self._relay_track = self._relay.subscribe(track)
|
||||
self._ingest_ready = True
|
||||
logger.info('whip ingest video track received', __name__)
|
||||
|
||||
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
|
||||
await pc.setRemoteDescription(offer)
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
return pc.localDescription.sdp
|
||||
|
||||
async def _create_viewer(self, sdp_offer : str) -> Optional[str]:
|
||||
pc = RTCPeerConnection()
|
||||
self._viewer_pcs.append(pc)
|
||||
|
||||
if self._relay_track:
|
||||
pc.addTrack(self._relay_track)
|
||||
|
||||
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
|
||||
await pc.setRemoteDescription(offer)
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
return pc.localDescription.sdp
|
||||
|
||||
def _run_http(self) -> None:
|
||||
bridge = self
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args) -> None:
|
||||
pass
|
||||
|
||||
def do_OPTIONS(self) -> None:
|
||||
self.send_response(200)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, DELETE')
|
||||
self.send_header('Access-Control-Allow-Headers', 'Content-Type, Authorization')
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self) -> None:
|
||||
content_length = int(self.headers.get('Content-Length', 0))
|
||||
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
|
||||
path = self.path
|
||||
|
||||
if '/whip' in path:
|
||||
answer = bridge._handle_whip(body)
|
||||
elif '/whep' in path:
|
||||
answer = bridge._handle_whep(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
if answer:
|
||||
self.send_response(201)
|
||||
self.send_header('Content-Type', 'application/sdp')
|
||||
self.send_header('Location', path)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Expose-Headers', 'Location')
|
||||
self.end_headers()
|
||||
self.wfile.write(answer.encode('utf-8'))
|
||||
else:
|
||||
self.send_response(500)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
|
||||
whip_server = HTTPServer(('0.0.0.0', self.whip_port), Handler)
|
||||
whip_server.timeout = 0.5
|
||||
whep_server = HTTPServer(('0.0.0.0', self.port), Handler)
|
||||
whep_server.timeout = 0.5
|
||||
|
||||
while self._running:
|
||||
whip_server.handle_request()
|
||||
whep_server.handle_request()
|
||||
@@ -805,6 +805,9 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None:
|
||||
with lock:
|
||||
latest_frame_holder[0] = frame
|
||||
|
||||
if data[:2] != JPEG_MAGIC:
|
||||
rtc.send_audio(stream_path, data)
|
||||
|
||||
except Exception as exception:
|
||||
logger.error(str(exception), __name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,592 @@
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import os
|
||||
import threading
|
||||
import time as _time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Dict, List, Optional, TypeAlias
|
||||
|
||||
import av
|
||||
import numpy
|
||||
|
||||
from facefusion import logger
|
||||
|
||||
RtcLib : TypeAlias = ctypes.CDLL
|
||||
WHEP_PORT : int = 8892
|
||||
|
||||
lib : Optional[RtcLib] = None
|
||||
sessions : Dict[str, dict] = {}
|
||||
http_thread : Optional[threading.Thread] = None
|
||||
running : bool = False
|
||||
|
||||
RTC_NEW = 0
|
||||
RTC_CONNECTING = 1
|
||||
RTC_CONNECTED = 2
|
||||
RTC_DISCONNECTED = 3
|
||||
RTC_FAILED = 4
|
||||
RTC_CLOSED = 5
|
||||
|
||||
RTC_GATHERING_NEW = 0
|
||||
RTC_GATHERING_INPROGRESS = 1
|
||||
RTC_GATHERING_COMPLETE = 2
|
||||
|
||||
RTC_DIRECTION_SENDONLY = 0
|
||||
RTC_DIRECTION_RECVONLY = 1
|
||||
RTC_DIRECTION_SENDRECV = 2
|
||||
RTC_DIRECTION_INACTIVE = 3
|
||||
RTC_DIRECTION_UNKNOWN = 4
|
||||
|
||||
LOG_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)
|
||||
DESCRIPTION_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
|
||||
CANDIDATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
|
||||
STATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
|
||||
GATHERING_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
|
||||
TRACK_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
|
||||
MESSAGE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p)
|
||||
|
||||
|
||||
class RtcConfiguration(ctypes.Structure):
|
||||
_fields_ =\
|
||||
[
|
||||
('iceServers', ctypes.POINTER(ctypes.c_char_p)),
|
||||
('iceServersCount', ctypes.c_int),
|
||||
('proxyServer', ctypes.c_char_p),
|
||||
('bindAddress', ctypes.c_char_p),
|
||||
('certificateType', ctypes.c_int),
|
||||
('iceTransportPolicy', ctypes.c_int),
|
||||
('enableIceTcp', ctypes.c_bool),
|
||||
('enableIceUdpMux', ctypes.c_bool),
|
||||
('disableAutoNegotiation', ctypes.c_bool),
|
||||
('forceMediaTransport', ctypes.c_bool),
|
||||
('portRangeBegin', ctypes.c_ushort),
|
||||
('portRangeEnd', ctypes.c_ushort),
|
||||
('mtu', ctypes.c_int),
|
||||
('maxMessageSize', ctypes.c_int)
|
||||
]
|
||||
|
||||
|
||||
class RtcPacketizerInit(ctypes.Structure):
|
||||
_fields_ =\
|
||||
[
|
||||
('ssrc', ctypes.c_uint32),
|
||||
('cname', ctypes.c_char_p),
|
||||
('payloadType', ctypes.c_uint8),
|
||||
('clockRate', ctypes.c_uint32),
|
||||
('sequenceNumber', ctypes.c_uint16),
|
||||
('timestamp', ctypes.c_uint32),
|
||||
('maxFragmentSize', ctypes.c_uint16),
|
||||
('nalSeparator', ctypes.c_int),
|
||||
('obuPacketization', ctypes.c_int),
|
||||
('playoutDelayId', ctypes.c_uint8),
|
||||
('playoutDelayMin', ctypes.c_uint16),
|
||||
('playoutDelayMax', ctypes.c_uint16),
|
||||
('colorSpaceId', ctypes.c_uint8),
|
||||
('colorChromaSitingHorz', ctypes.c_uint8),
|
||||
('colorChromaSitingVert', ctypes.c_uint8),
|
||||
('colorRange', ctypes.c_uint8),
|
||||
('colorPrimaries', ctypes.c_uint8),
|
||||
('colorTransfer', ctypes.c_uint8),
|
||||
('colorMatrix', ctypes.c_uint8)
|
||||
]
|
||||
|
||||
|
||||
def find_library() -> Optional[str]:
|
||||
lib_path = ctypes.util.find_library('datachannel')
|
||||
|
||||
if lib_path:
|
||||
return lib_path
|
||||
|
||||
search_paths =\
|
||||
[
|
||||
'/home/henry/local/lib/libdatachannel.so',
|
||||
'/usr/local/lib/libdatachannel.so',
|
||||
'/usr/lib/libdatachannel.so',
|
||||
'/usr/lib/x86_64-linux-gnu/libdatachannel.so'
|
||||
]
|
||||
|
||||
for path in search_paths:
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def load_library() -> bool:
|
||||
global lib
|
||||
|
||||
lib_path = find_library()
|
||||
|
||||
if not lib_path:
|
||||
logger.warn('libdatachannel.so not found', __name__)
|
||||
return False
|
||||
|
||||
lib = ctypes.CDLL(lib_path)
|
||||
setup_prototypes()
|
||||
lib.rtcInitLogger(4, LOG_CALLBACK_TYPE(0))
|
||||
logger.info('libdatachannel loaded from ' + lib_path, __name__)
|
||||
return True
|
||||
|
||||
|
||||
def setup_prototypes() -> None:
|
||||
lib.rtcInitLogger.argtypes = [ctypes.c_int, LOG_CALLBACK_TYPE]
|
||||
lib.rtcInitLogger.restype = None
|
||||
|
||||
lib.rtcCreatePeerConnection.argtypes = [ctypes.POINTER(RtcConfiguration)]
|
||||
lib.rtcCreatePeerConnection.restype = ctypes.c_int
|
||||
|
||||
lib.rtcDeletePeerConnection.argtypes = [ctypes.c_int]
|
||||
lib.rtcDeletePeerConnection.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p]
|
||||
lib.rtcSetLocalDescription.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetRemoteDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
|
||||
lib.rtcSetRemoteDescription.restype = ctypes.c_int
|
||||
|
||||
lib.rtcGetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int]
|
||||
lib.rtcGetLocalDescription.restype = ctypes.c_int
|
||||
|
||||
lib.rtcAddTrack.argtypes = [ctypes.c_int, ctypes.c_char_p]
|
||||
lib.rtcAddTrack.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetUserPointer.argtypes = [ctypes.c_int, ctypes.c_void_p]
|
||||
lib.rtcSetUserPointer.restype = None
|
||||
|
||||
lib.rtcSetLocalDescriptionCallback.argtypes = [ctypes.c_int, DESCRIPTION_CALLBACK_TYPE]
|
||||
lib.rtcSetLocalDescriptionCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetLocalCandidateCallback.argtypes = [ctypes.c_int, CANDIDATE_CALLBACK_TYPE]
|
||||
lib.rtcSetLocalCandidateCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetStateChangeCallback.argtypes = [ctypes.c_int, STATE_CALLBACK_TYPE]
|
||||
lib.rtcSetStateChangeCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetGatheringStateChangeCallback.argtypes = [ctypes.c_int, GATHERING_CALLBACK_TYPE]
|
||||
lib.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetTrackCallback.argtypes = [ctypes.c_int, TRACK_CALLBACK_TYPE]
|
||||
lib.rtcSetTrackCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetMessageCallback.argtypes = [ctypes.c_int, MESSAGE_CALLBACK_TYPE]
|
||||
lib.rtcSetMessageCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSendMessage.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
|
||||
lib.rtcSendMessage.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetH264Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
|
||||
lib.rtcSetH264Packetizer.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetVP8Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
|
||||
lib.rtcSetVP8Packetizer.restype = ctypes.c_int
|
||||
|
||||
lib.rtcChainRtcpSrReporter.argtypes = [ctypes.c_int]
|
||||
lib.rtcChainRtcpSrReporter.restype = ctypes.c_int
|
||||
|
||||
lib.rtcChainRtcpNackResponder.argtypes = [ctypes.c_int, ctypes.c_uint]
|
||||
lib.rtcChainRtcpNackResponder.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetTrackRtpTimestamp.argtypes = [ctypes.c_int, ctypes.c_uint32]
|
||||
lib.rtcSetTrackRtpTimestamp.restype = ctypes.c_int
|
||||
|
||||
lib.rtcSetOpenCallback.argtypes = [ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)]
|
||||
lib.rtcSetOpenCallback.restype = ctypes.c_int
|
||||
|
||||
lib.rtcIsOpen.argtypes = [ctypes.c_int]
|
||||
lib.rtcIsOpen.restype = ctypes.c_bool
|
||||
|
||||
lib.rtcSetOpusPacketizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
|
||||
lib.rtcSetOpusPacketizer.restype = ctypes.c_int
|
||||
|
||||
|
||||
callback_refs : List = []
|
||||
|
||||
|
||||
def create_peer_connection() -> int:
|
||||
config = RtcConfiguration()
|
||||
config.iceServers = None
|
||||
config.iceServersCount = 0
|
||||
config.proxyServer = None
|
||||
config.bindAddress = None
|
||||
config.certificateType = 0
|
||||
config.iceTransportPolicy = 0
|
||||
config.enableIceTcp = False
|
||||
config.enableIceUdpMux = True
|
||||
config.disableAutoNegotiation = False
|
||||
config.forceMediaTransport = True
|
||||
config.portRangeBegin = 0
|
||||
config.portRangeEnd = 0
|
||||
config.mtu = 0
|
||||
config.maxMessageSize = 0
|
||||
return lib.rtcCreatePeerConnection(ctypes.byref(config))
|
||||
|
||||
|
||||
next_rtp_port : int = 16000
|
||||
|
||||
|
||||
def create_session(stream_path : str) -> None:
|
||||
sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': 0, 'rtp_fd': None}
|
||||
|
||||
|
||||
def create_rtp_session(stream_path : str) -> int:
|
||||
global next_rtp_port
|
||||
import socket as sock
|
||||
|
||||
rtp_port = next_rtp_port
|
||||
next_rtp_port += 1
|
||||
|
||||
rtp_fd = sock.socket(sock.AF_INET, sock.SOCK_DGRAM)
|
||||
rtp_fd.bind(('127.0.0.1', rtp_port))
|
||||
rtp_fd.settimeout(1.0)
|
||||
|
||||
sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': rtp_port, 'rtp_fd': rtp_fd}
|
||||
|
||||
rtp_thread = threading.Thread(target = run_rtp_forwarder, args = (stream_path,), daemon = True)
|
||||
rtp_thread.start()
|
||||
|
||||
return rtp_port
|
||||
|
||||
|
||||
def run_rtp_forwarder(stream_path : str) -> None:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
rtp_fd = session.get('rtp_fd')
|
||||
|
||||
while running and session.get('rtp_fd'):
|
||||
try:
|
||||
data, addr = rtp_fd.recvfrom(262144)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
send_to_viewers(stream_path, data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
send_start_time : float = 0
|
||||
opus_encoder : Optional[av.CodecContext] = None
|
||||
audio_buffer : bytearray = bytearray()
|
||||
audio_lock : threading.Lock = threading.Lock()
|
||||
OPUS_FRAME_SAMPLES : int = 960
|
||||
|
||||
|
||||
def send_to_viewers(stream_path : str, data : bytes) -> None:
|
||||
global send_start_time
|
||||
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
viewers = session.get('viewers')
|
||||
|
||||
if not viewers:
|
||||
return
|
||||
|
||||
if send_start_time == 0:
|
||||
send_start_time = _time.monotonic()
|
||||
|
||||
elapsed = _time.monotonic() - send_start_time
|
||||
timestamp = int(elapsed * 90000) & 0xFFFFFFFF
|
||||
buf = ctypes.create_string_buffer(data)
|
||||
data_len = len(data)
|
||||
|
||||
for viewer in viewers:
|
||||
if not viewer.get('connected'):
|
||||
continue
|
||||
|
||||
for track_id in viewer.get('tracks', []):
|
||||
if not lib.rtcIsOpen(track_id):
|
||||
continue
|
||||
|
||||
lib.rtcSetTrackRtpTimestamp(track_id, timestamp)
|
||||
lib.rtcSendMessage(track_id, buf, data_len)
|
||||
|
||||
|
||||
def get_opus_encoder() -> av.CodecContext:
|
||||
global opus_encoder
|
||||
|
||||
if not opus_encoder:
|
||||
opus_encoder = av.CodecContext.create('libopus', 'w')
|
||||
opus_encoder.sample_rate = 48000
|
||||
opus_encoder.layout = 'stereo'
|
||||
opus_encoder.format = av.AudioFormat('s16')
|
||||
opus_encoder.open()
|
||||
|
||||
return opus_encoder
|
||||
|
||||
|
||||
def send_audio(stream_path : str, pcm_data : bytes) -> None:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
viewers = session.get('viewers')
|
||||
|
||||
if not viewers:
|
||||
return
|
||||
|
||||
with audio_lock:
|
||||
audio_buffer.extend(pcm_data)
|
||||
needed = OPUS_FRAME_SAMPLES * 2 * 2
|
||||
|
||||
while len(audio_buffer) >= needed:
|
||||
chunk = bytes(audio_buffer[:needed])
|
||||
del audio_buffer[:needed]
|
||||
|
||||
encoder = get_opus_encoder()
|
||||
pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1)
|
||||
frame = av.AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo')
|
||||
frame.sample_rate = 48000
|
||||
frame.pts = None
|
||||
|
||||
for packet in encoder.encode(frame):
|
||||
opus_data = bytes(packet)
|
||||
|
||||
for viewer in viewers:
|
||||
if not viewer.get('connected'):
|
||||
continue
|
||||
|
||||
audio_track_id = viewer.get('audio_track')
|
||||
|
||||
if not audio_track_id:
|
||||
continue
|
||||
|
||||
if not lib.rtcIsOpen(audio_track_id):
|
||||
continue
|
||||
|
||||
elapsed = _time.monotonic() - send_start_time if send_start_time > 0 else 0
|
||||
timestamp = int(elapsed * 48000) & 0xFFFFFFFF
|
||||
buf = ctypes.create_string_buffer(opus_data)
|
||||
lib.rtcSetTrackRtpTimestamp(audio_track_id, timestamp)
|
||||
lib.rtcSendMessage(audio_track_id, buf, len(opus_data))
|
||||
|
||||
|
||||
h264_au_buffer : Dict[str, bytes] = {}
|
||||
|
||||
|
||||
def send_vp8_frame(stream_path : str, frame_data : bytes) -> None:
|
||||
send_h264_frame(stream_path, frame_data)
|
||||
|
||||
|
||||
def send_h264_frame(stream_path : str, frame_data : bytes) -> None:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
prev = h264_au_buffer.get(stream_path, b'')
|
||||
buf = prev + frame_data
|
||||
|
||||
au_starts = []
|
||||
i = 0
|
||||
|
||||
while i < len(buf) - 4:
|
||||
if buf[i] == 0 and buf[i + 1] == 0 and buf[i + 2] == 0 and buf[i + 3] == 1 and i + 4 < len(buf):
|
||||
nal_type = buf[i + 4] & 0x1f
|
||||
|
||||
if nal_type == 7 or nal_type == 5:
|
||||
au_starts.append(i)
|
||||
|
||||
i += 1
|
||||
|
||||
if len(au_starts) < 2:
|
||||
h264_au_buffer[stream_path] = buf
|
||||
return
|
||||
|
||||
for j in range(len(au_starts) - 1):
|
||||
au = buf[au_starts[j]:au_starts[j + 1]]
|
||||
|
||||
for viewer in session.get('viewers', []):
|
||||
tracks = viewer.get('tracks', [])
|
||||
|
||||
if tracks:
|
||||
lib.rtcSendMessage(tracks[0], au, len(au))
|
||||
|
||||
h264_au_buffer[stream_path] = buf[au_starts[-1]:]
|
||||
|
||||
|
||||
def destroy_session(stream_path : str) -> None:
|
||||
session = sessions.pop(stream_path, None)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
for viewer in session.get('viewers', []):
|
||||
pc_id = viewer.get('pc')
|
||||
|
||||
if pc_id is not None:
|
||||
lib.rtcDeletePeerConnection(pc_id)
|
||||
|
||||
|
||||
def send_data(stream_path : str, data : bytes) -> None:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return
|
||||
|
||||
for viewer in session.get('viewers', []):
|
||||
for track_id in viewer.get('tracks', []):
|
||||
lib.rtcSendMessage(track_id, data, len(data))
|
||||
|
||||
|
||||
def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
if not lib:
|
||||
return None
|
||||
|
||||
pc = create_peer_connection()
|
||||
gathering_done = threading.Event()
|
||||
local_sdp_holder = [None]
|
||||
|
||||
def on_description(pc_id, sdp, type_str, user_ptr):
|
||||
local_sdp_holder[0] = sdp.decode('utf-8') if sdp else None
|
||||
|
||||
def on_candidate(pc_id, candidate, mid, user_ptr):
|
||||
pass
|
||||
|
||||
def on_gathering(pc_id, state, user_ptr):
|
||||
if state == RTC_GATHERING_COMPLETE:
|
||||
gathering_done.set()
|
||||
|
||||
viewer = {'pc': pc, 'tracks': [], 'connected': False}
|
||||
|
||||
def on_state(pc_id, state, user_ptr):
|
||||
if state == RTC_CONNECTED:
|
||||
viewer['connected'] = True
|
||||
logger.info('viewer pc connected', __name__)
|
||||
|
||||
desc_cb = DESCRIPTION_CALLBACK_TYPE(on_description)
|
||||
cand_cb = CANDIDATE_CALLBACK_TYPE(on_candidate)
|
||||
gather_cb = GATHERING_CALLBACK_TYPE(on_gathering)
|
||||
state_cb = STATE_CALLBACK_TYPE(on_state)
|
||||
callback_refs.extend([desc_cb, cand_cb, gather_cb, state_cb])
|
||||
|
||||
lib.rtcSetLocalDescriptionCallback(pc, desc_cb)
|
||||
lib.rtcSetLocalCandidateCallback(pc, cand_cb)
|
||||
lib.rtcSetGatheringStateChangeCallback(pc, gather_cb)
|
||||
lib.rtcSetStateChangeCallback(pc, state_cb)
|
||||
|
||||
video_sdp = b'm=video 9 UDP/TLS/RTP/SAVPF 96\r\na=rtpmap:96 VP8/90000\r\na=sendonly\r\na=mid:0\r\na=rtcp-mux\r\n'
|
||||
audio_sdp = b'm=audio 9 UDP/TLS/RTP/SAVPF 111\r\na=rtpmap:111 opus/48000/2\r\na=sendonly\r\na=mid:1\r\na=rtcp-mux\r\n'
|
||||
|
||||
video_track = lib.rtcAddTrack(pc, video_sdp)
|
||||
audio_track = lib.rtcAddTrack(pc, audio_sdp)
|
||||
|
||||
video_packetizer = RtcPacketizerInit()
|
||||
video_packetizer.ssrc = 42
|
||||
video_packetizer.cname = b'video'
|
||||
video_packetizer.payloadType = 96
|
||||
video_packetizer.clockRate = 90000
|
||||
video_packetizer.maxFragmentSize = 1200
|
||||
lib.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
|
||||
lib.rtcChainRtcpSrReporter(video_track)
|
||||
lib.rtcChainRtcpNackResponder(video_track, 512)
|
||||
|
||||
audio_packetizer = RtcPacketizerInit()
|
||||
audio_packetizer.ssrc = 43
|
||||
audio_packetizer.cname = b'audio'
|
||||
audio_packetizer.payloadType = 111
|
||||
audio_packetizer.clockRate = 48000
|
||||
lib.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
|
||||
lib.rtcChainRtcpSrReporter(audio_track)
|
||||
|
||||
viewer['tracks'] = [video_track]
|
||||
viewer['audio_track'] = audio_track
|
||||
session['viewers'].append(viewer)
|
||||
|
||||
lib.rtcSetRemoteDescription(pc, sdp_offer.encode('utf-8'), b'offer')
|
||||
|
||||
gathering_done.wait(timeout = 3)
|
||||
|
||||
buf_size = 16384
|
||||
buf = ctypes.create_string_buffer(buf_size)
|
||||
result = lib.rtcGetLocalDescription(pc, buf, buf_size)
|
||||
|
||||
if result > 0:
|
||||
local_sdp = buf.value.decode('utf-8')
|
||||
elif local_sdp_holder[0]:
|
||||
local_sdp = local_sdp_holder[0]
|
||||
else:
|
||||
session['viewers'].remove(viewer)
|
||||
return None
|
||||
|
||||
return local_sdp
|
||||
|
||||
|
||||
def start() -> None:
|
||||
global running, http_thread
|
||||
|
||||
if not load_library():
|
||||
return
|
||||
|
||||
running = True
|
||||
http_thread = threading.Thread(target = run_http_server, daemon = True)
|
||||
http_thread.start()
|
||||
logger.info('rtc whep server started on port ' + str(WHEP_PORT), __name__)
|
||||
|
||||
|
||||
def stop() -> None:
|
||||
global running
|
||||
|
||||
running = False
|
||||
|
||||
for stream_path in list(sessions.keys()):
|
||||
destroy_session(stream_path)
|
||||
|
||||
|
||||
def run_http_server() -> None:
|
||||
class WhepHandler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args) -> None:
|
||||
pass
|
||||
|
||||
def send_cors_headers(self) -> None:
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
|
||||
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
|
||||
|
||||
def do_OPTIONS(self) -> None:
|
||||
self.send_response(200)
|
||||
self.send_cors_headers()
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self) -> None:
|
||||
path = self.path
|
||||
|
||||
if not path.endswith('/whep'):
|
||||
self.send_response(404)
|
||||
self.send_cors_headers()
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
stream_path = path[1:].rsplit('/whep', 1)[0]
|
||||
content_length = int(self.headers.get('Content-Length', 0))
|
||||
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
|
||||
answer = handle_whep_offer(stream_path, body)
|
||||
|
||||
if answer:
|
||||
self.send_response(201)
|
||||
self.send_header('Content-Type', 'application/sdp')
|
||||
self.send_header('Location', path)
|
||||
self.send_cors_headers()
|
||||
self.end_headers()
|
||||
self.wfile.write(answer.encode('utf-8'))
|
||||
return
|
||||
|
||||
self.send_response(404)
|
||||
self.send_cors_headers()
|
||||
self.end_headers()
|
||||
|
||||
server = HTTPServer(('0.0.0.0', WHEP_PORT), WhepHandler)
|
||||
server.timeout = 1
|
||||
|
||||
while running:
|
||||
server.handle_request()
|
||||
@@ -0,0 +1,546 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple, TypeAlias
|
||||
|
||||
import pylibsrtp
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from OpenSSL import SSL
|
||||
|
||||
from facefusion import logger
|
||||
|
||||
SrtpSession : TypeAlias = pylibsrtp.Session
|
||||
SrtpPolicy : TypeAlias = pylibsrtp.Policy
|
||||
|
||||
WHIP_PORT : int = 8890
|
||||
ICE_UFRAG_LENGTH : int = 4
|
||||
ICE_PWD_LENGTH : int = 22
|
||||
RTP_HEADER_SIZE : int = 12
|
||||
|
||||
SRTP_PROFILES =\
|
||||
[
|
||||
{
|
||||
'name': b'SRTP_AES128_CM_SHA1_80',
|
||||
'libsrtp': SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80,
|
||||
'key_len': 16,
|
||||
'salt_len': 14
|
||||
}
|
||||
]
|
||||
|
||||
sessions : Dict[str, dict] = {}
|
||||
server_cert = None
|
||||
server_key = None
|
||||
server_fingerprint : str = ''
|
||||
udp_socket : Optional[socket.socket] = None
|
||||
http_thread : Optional[threading.Thread] = None
|
||||
udp_thread : Optional[threading.Thread] = None
|
||||
running : bool = False
|
||||
|
||||
|
||||
def generate_credentials() -> None:
|
||||
global server_cert, server_key, server_fingerprint
|
||||
|
||||
server_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
|
||||
name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, binascii.hexlify(os.urandom(16)).decode())])
|
||||
import datetime
|
||||
now = datetime.datetime.now(tz = datetime.timezone.utc)
|
||||
builder = x509.CertificateBuilder().subject_name(name).issuer_name(name).public_key(server_key.public_key()).serial_number(x509.random_serial_number()).not_valid_before(now - datetime.timedelta(days = 1)).not_valid_after(now + datetime.timedelta(days = 30))
|
||||
server_cert = builder.sign(server_key, hashes.SHA256(), default_backend())
|
||||
fp = server_cert.fingerprint(hashes.SHA256()).hex().upper()
|
||||
server_fingerprint = ':'.join(fp[i:i + 2] for i in range(0, len(fp), 2))
|
||||
|
||||
|
||||
def generate_ice_credentials() -> Tuple[str, str]:
|
||||
ufrag = binascii.hexlify(os.urandom(ICE_UFRAG_LENGTH)).decode()
|
||||
pwd = binascii.hexlify(os.urandom(ICE_PWD_LENGTH)).decode()[:ICE_PWD_LENGTH]
|
||||
return ufrag, pwd
|
||||
|
||||
|
||||
def parse_sdp_offer(sdp : str) -> dict:
|
||||
result = {'ice_ufrag': '', 'ice_pwd': '', 'fingerprint': '', 'setup': '', 'media': [], 'candidates': []}
|
||||
current_media = None
|
||||
|
||||
for line in sdp.splitlines():
|
||||
line = line.strip()
|
||||
|
||||
if line.startswith('a=ice-ufrag:'):
|
||||
result['ice_ufrag'] = line.split(':', 1)[1]
|
||||
if line.startswith('a=ice-pwd:'):
|
||||
result['ice_pwd'] = line.split(':', 1)[1]
|
||||
if line.startswith('a=fingerprint:'):
|
||||
result['fingerprint'] = line.split(' ', 1)[1] if ' ' in line else ''
|
||||
if line.startswith('a=setup:'):
|
||||
result['setup'] = line.split(':', 1)[1]
|
||||
if line.startswith('a=candidate:'):
|
||||
result['candidates'].append(line[12:])
|
||||
if line.startswith('m='):
|
||||
parts = line[2:].split()
|
||||
current_media = {'kind': parts[0], 'port': int(parts[1]), 'profile': parts[2], 'formats': parts[3:], 'codec_lines': [], 'mid': None}
|
||||
result['media'].append(current_media)
|
||||
if current_media:
|
||||
if line.startswith('a=rtpmap:') or line.startswith('a=fmtp:') or line.startswith('a=rtcp-fb:'):
|
||||
current_media['codec_lines'].append(line)
|
||||
if line.startswith('a=mid:'):
|
||||
current_media['mid'] = line.split(':', 1)[1]
|
||||
if line.startswith('a=extmap:'):
|
||||
current_media['codec_lines'].append(line)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_sdp_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int) -> str:
|
||||
lines = []
|
||||
lines.append('v=0')
|
||||
lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
|
||||
lines.append('s=-')
|
||||
lines.append('t=0 0')
|
||||
|
||||
mids = []
|
||||
for i, media in enumerate(offer.get('media', [])):
|
||||
mids.append(str(i))
|
||||
|
||||
if mids:
|
||||
lines.append('a=group:BUNDLE ' + ' '.join(mids))
|
||||
|
||||
lines.append('a=ice-lite')
|
||||
|
||||
for i, media in enumerate(offer.get('media', [])):
|
||||
kind = media.get('kind')
|
||||
formats = media.get('formats', [])
|
||||
profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
|
||||
mid = media.get('mid', str(i))
|
||||
lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
|
||||
lines.append('c=IN IP4 127.0.0.1')
|
||||
lines.append('a=rtcp:9 IN IP4 0.0.0.0')
|
||||
lines.append('a=ice-ufrag:' + local_ufrag)
|
||||
lines.append('a=ice-pwd:' + local_pwd)
|
||||
lines.append('a=ice-options:ice2')
|
||||
lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
|
||||
lines.append('a=setup:passive')
|
||||
lines.append('a=mid:' + mid)
|
||||
lines.append('a=rtcp-mux')
|
||||
lines.append('a=recvonly')
|
||||
|
||||
for codec_line in media.get('codec_lines', []):
|
||||
lines.append(codec_line)
|
||||
|
||||
lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
|
||||
lines.append('a=end-of-candidates')
|
||||
|
||||
return '\r\n'.join(lines) + '\r\n'
|
||||
|
||||
|
||||
def build_whep_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int, ingest_offer : dict) -> str:
|
||||
lines = []
|
||||
lines.append('v=0')
|
||||
lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
|
||||
lines.append('s=-')
|
||||
lines.append('t=0 0')
|
||||
|
||||
mids = []
|
||||
for i, media in enumerate(offer.get('media', [])):
|
||||
mid = media.get('mid', str(i))
|
||||
mids.append(mid)
|
||||
|
||||
if mids:
|
||||
lines.append('a=group:BUNDLE ' + ' '.join(mids))
|
||||
|
||||
lines.append('a=ice-lite')
|
||||
|
||||
for i, media in enumerate(offer.get('media', [])):
|
||||
kind = media.get('kind')
|
||||
formats = media.get('formats', [])
|
||||
profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
|
||||
mid = media.get('mid', str(i))
|
||||
lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
|
||||
lines.append('c=IN IP4 127.0.0.1')
|
||||
lines.append('a=rtcp:9 IN IP4 0.0.0.0')
|
||||
lines.append('a=ice-ufrag:' + local_ufrag)
|
||||
lines.append('a=ice-pwd:' + local_pwd)
|
||||
lines.append('a=ice-options:ice2')
|
||||
lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
|
||||
lines.append('a=setup:passive')
|
||||
lines.append('a=mid:' + mid)
|
||||
lines.append('a=rtcp-mux')
|
||||
lines.append('a=sendonly')
|
||||
|
||||
for codec_line in media.get('codec_lines', []):
|
||||
lines.append(codec_line)
|
||||
|
||||
lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
|
||||
lines.append('a=end-of-candidates')
|
||||
|
||||
return '\r\n'.join(lines) + '\r\n'
|
||||
|
||||
|
||||
def create_ssl_context() -> SSL.Context:
|
||||
ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True)
|
||||
ctx.use_certificate(server_cert)
|
||||
ctx.use_privatekey(server_key)
|
||||
ctx.set_cipher_list(b'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-SHA')
|
||||
ctx.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80')
|
||||
return ctx
|
||||
|
||||
|
||||
def create_session(stream_path : str) -> None:
|
||||
ufrag, pwd = generate_ice_credentials()
|
||||
sessions[stream_path] = {
|
||||
'ice_ufrag': ufrag,
|
||||
'ice_pwd': pwd,
|
||||
'ingest': None,
|
||||
'viewers': [],
|
||||
'ingest_offer': None,
|
||||
'rx_srtp': None,
|
||||
'tx_sessions': []
|
||||
}
|
||||
|
||||
|
||||
def destroy_session(stream_path : str) -> None:
|
||||
sessions.pop(stream_path, None)
|
||||
|
||||
|
||||
def handle_whip(stream_path : str, sdp_offer : str) -> Optional[str]:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
offer = parse_sdp_offer(sdp_offer)
|
||||
session['ingest_offer'] = offer
|
||||
local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
|
||||
answer = build_sdp_answer(offer, session.get('ice_ufrag'), session.get('ice_pwd'), local_port)
|
||||
return answer
|
||||
|
||||
|
||||
def handle_whep(stream_path : str, sdp_offer : str) -> Optional[str]:
|
||||
session = sessions.get(stream_path)
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
offer = parse_sdp_offer(sdp_offer)
|
||||
viewer_ufrag, viewer_pwd = generate_ice_credentials()
|
||||
local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
|
||||
ingest_offer = session.get('ingest_offer', offer)
|
||||
answer = build_whep_answer(offer, viewer_ufrag, viewer_pwd, local_port, ingest_offer)
|
||||
session['viewers'].append({'offer': offer, 'ice_ufrag': viewer_ufrag, 'ice_pwd': viewer_pwd})
|
||||
return answer
|
||||
|
||||
|
||||
def start() -> None:
|
||||
global running, udp_socket, http_thread
|
||||
|
||||
generate_credentials()
|
||||
running = True
|
||||
|
||||
udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
udp_socket.bind(('0.0.0.0', WHIP_PORT))
|
||||
udp_socket.settimeout(1.0)
|
||||
|
||||
udp_thread_instance = threading.Thread(target = run_udp_loop, daemon = True)
|
||||
udp_thread_instance.start()
|
||||
|
||||
http_thread = threading.Thread(target = run_http_server, daemon = True)
|
||||
http_thread.start()
|
||||
logger.info('webrtc sfu started on port ' + str(WHIP_PORT), __name__)
|
||||
|
||||
|
||||
def stop() -> None:
|
||||
global running, udp_socket
|
||||
|
||||
running = False
|
||||
|
||||
if udp_socket:
|
||||
udp_socket.close()
|
||||
udp_socket = None
|
||||
|
||||
|
||||
dtls_connections : Dict[tuple, dict] = {}
|
||||
|
||||
|
||||
def run_udp_loop() -> None:
|
||||
while running:
|
||||
try:
|
||||
data, addr = udp_socket.recvfrom(2048)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
first_byte = data[0]
|
||||
|
||||
if first_byte == 0 or first_byte == 1:
|
||||
handle_stun(data, addr)
|
||||
if first_byte > 19 and first_byte < 64:
|
||||
handle_dtls(data, addr)
|
||||
if first_byte > 127 and first_byte < 192:
|
||||
handle_srtp(data, addr)
|
||||
|
||||
except socket.timeout:
|
||||
continue
|
||||
except Exception:
|
||||
if running:
|
||||
continue
|
||||
|
||||
|
||||
def handle_dtls(data : bytes, addr : tuple) -> None:
|
||||
conn = dtls_connections.get(addr)
|
||||
|
||||
if not conn:
|
||||
ctx = create_ssl_context()
|
||||
ssl_conn = SSL.Connection(ctx)
|
||||
ssl_conn.set_accept_state()
|
||||
conn = {'ssl': ssl_conn, 'encrypted': False, 'rx_srtp': None, 'tx_srtp': None}
|
||||
dtls_connections[addr] = conn
|
||||
|
||||
ssl_conn = conn.get('ssl')
|
||||
ssl_conn.bio_write(data)
|
||||
|
||||
try:
|
||||
if not conn.get('encrypted'):
|
||||
try:
|
||||
ssl_conn.do_handshake()
|
||||
conn['encrypted'] = True
|
||||
setup_srtp_session(conn)
|
||||
logger.info('dtls handshake complete with ' + str(addr), __name__)
|
||||
except SSL.WantReadError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
ssl_conn.recv(1500)
|
||||
except SSL.ZeroReturnError:
|
||||
pass
|
||||
except SSL.Error:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
flush_dtls(ssl_conn, addr)
|
||||
|
||||
|
||||
def flush_dtls(ssl_conn : SSL.Connection, addr : tuple) -> None:
|
||||
try:
|
||||
outdata = ssl_conn.bio_read(1500)
|
||||
|
||||
if outdata:
|
||||
udp_socket.sendto(outdata, addr)
|
||||
except SSL.Error:
|
||||
pass
|
||||
|
||||
|
||||
def setup_srtp_session(conn : dict) -> None:
|
||||
ssl_conn = conn.get('ssl')
|
||||
ssl_conn.get_selected_srtp_profile()
|
||||
key_len = 16
|
||||
salt_len = 14
|
||||
view = ssl_conn.export_keying_material(b'EXTRACTOR-dtls_srtp', 2 * (key_len + salt_len))
|
||||
server_key = view[key_len:2 * key_len] + view[2 * key_len + salt_len:]
|
||||
client_key = view[:key_len] + view[2 * key_len:2 * key_len + salt_len]
|
||||
|
||||
rx_policy = SrtpPolicy(key = client_key, ssrc_type = SrtpPolicy.SSRC_ANY_INBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
|
||||
rx_policy.allow_repeat_tx = True
|
||||
rx_policy.window_size = 1024
|
||||
conn['rx_srtp'] = SrtpSession(rx_policy)
|
||||
|
||||
tx_policy = SrtpPolicy(key = server_key, ssrc_type = SrtpPolicy.SSRC_ANY_OUTBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
|
||||
tx_policy.allow_repeat_tx = True
|
||||
tx_policy.window_size = 1024
|
||||
conn['tx_srtp'] = SrtpSession(tx_policy)
|
||||
|
||||
|
||||
def is_rtcp(data : bytes) -> bool:
|
||||
if len(data) < 2:
|
||||
return False
|
||||
pt = data[1] & 0x7F
|
||||
return 64 <= pt <= 95
|
||||
|
||||
|
||||
def handle_srtp(data : bytes, addr : tuple) -> None:
|
||||
conn = dtls_connections.get(addr)
|
||||
|
||||
if not conn or not conn.get('rx_srtp'):
|
||||
return
|
||||
|
||||
try:
|
||||
if is_rtcp(data):
|
||||
plain = conn.get('rx_srtp').unprotect_rtcp(data)
|
||||
else:
|
||||
plain = conn.get('rx_srtp').unprotect(data)
|
||||
|
||||
forward_rtp(plain, addr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def forward_rtp(data : bytes, source_addr : tuple) -> None:
|
||||
for other_addr, conn in dtls_connections.items():
|
||||
if other_addr == source_addr:
|
||||
continue
|
||||
|
||||
if not conn.get('tx_srtp'):
|
||||
continue
|
||||
|
||||
try:
|
||||
if is_rtcp(data):
|
||||
encrypted = conn.get('tx_srtp').protect_rtcp(data)
|
||||
else:
|
||||
encrypted = conn.get('tx_srtp').protect(data)
|
||||
udp_socket.sendto(encrypted, other_addr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def handle_stun(data : bytes, addr : tuple) -> None:
|
||||
if len(data) < 20:
|
||||
return
|
||||
|
||||
msg_type = struct.unpack('!H', data[0:2])[0]
|
||||
|
||||
if msg_type != 0x0001:
|
||||
return
|
||||
|
||||
msg_length = struct.unpack('!H', data[2:4])[0]
|
||||
transaction_id = data[8:20]
|
||||
|
||||
username = None
|
||||
offset = 20
|
||||
|
||||
while offset < 20 + msg_length:
|
||||
if offset + 4 > len(data):
|
||||
break
|
||||
attr_type = struct.unpack('!H', data[offset:offset + 2])[0]
|
||||
attr_length = struct.unpack('!H', data[offset + 2:offset + 4])[0]
|
||||
attr_value = data[offset + 4:offset + 4 + attr_length]
|
||||
|
||||
if attr_type == 0x0006:
|
||||
username = attr_value.decode('utf-8', errors = 'ignore')
|
||||
|
||||
padded = attr_length + (4 - attr_length % 4) % 4
|
||||
offset += 4 + padded
|
||||
|
||||
if not username:
|
||||
return
|
||||
|
||||
local_ufrag = username.split(':')[0] if ':' in username else username
|
||||
session_pwd = None
|
||||
|
||||
for session in sessions.values():
|
||||
if session.get('ice_ufrag') == local_ufrag:
|
||||
session_pwd = session.get('ice_pwd')
|
||||
break
|
||||
|
||||
for viewer in session.get('viewers', []):
|
||||
if viewer.get('ice_ufrag') == local_ufrag:
|
||||
session_pwd = viewer.get('ice_pwd')
|
||||
break
|
||||
|
||||
if session_pwd:
|
||||
break
|
||||
|
||||
if not session_pwd:
|
||||
return
|
||||
|
||||
response = build_stun_response(transaction_id, addr, session_pwd)
|
||||
udp_socket.sendto(response, addr)
|
||||
|
||||
|
||||
def build_stun_response(transaction_id : bytes, addr : tuple, password : str) -> bytes:
|
||||
import hmac
|
||||
import zlib
|
||||
|
||||
magic_cookie = 0x2112A442
|
||||
magic_bytes = struct.pack('!I', magic_cookie)
|
||||
|
||||
xport = addr[1] ^ (magic_cookie >> 16)
|
||||
ip_int = struct.unpack('!I', socket.inet_aton(addr[0]))[0]
|
||||
xip = struct.pack('!I', ip_int ^ magic_cookie)
|
||||
xor_addr_value = struct.pack('!BBH', 0, 0x01, xport) + xip
|
||||
xor_addr_attr = struct.pack('!HH', 0x0020, len(xor_addr_value)) + xor_addr_value
|
||||
|
||||
attrs_before_integrity = xor_addr_attr
|
||||
integrity_dummy_len = len(attrs_before_integrity) + 4 + 20
|
||||
header_for_hmac = struct.pack('!HH', 0x0101, integrity_dummy_len) + magic_bytes + transaction_id
|
||||
key = password.encode('utf-8')
|
||||
integrity = hmac.new(key, header_for_hmac + attrs_before_integrity, hashlib.sha1).digest()
|
||||
integrity_attr = struct.pack('!HH', 0x0008, 20) + integrity
|
||||
|
||||
attrs_before_fp = attrs_before_integrity + integrity_attr
|
||||
fp_dummy_len = len(attrs_before_fp) + 4 + 4
|
||||
header_for_fp = struct.pack('!HH', 0x0101, fp_dummy_len) + magic_bytes + transaction_id
|
||||
crc = zlib.crc32(header_for_fp + attrs_before_fp) ^ 0x5354554E
|
||||
fingerprint_attr = struct.pack('!HHI', 0x8028, 4, crc & 0xFFFFFFFF)
|
||||
|
||||
all_attrs = attrs_before_integrity + integrity_attr + fingerprint_attr
|
||||
header = struct.pack('!HH', 0x0101, len(all_attrs)) + magic_bytes + transaction_id
|
||||
return header + all_attrs
|
||||
|
||||
|
||||
def run_http_server() -> None:
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
class WhipWhepHandler(BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args) -> None:
|
||||
pass
|
||||
|
||||
def do_POST(self) -> None:
|
||||
path = self.path
|
||||
content_length = int(self.headers.get('Content-Length', 0))
|
||||
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
|
||||
|
||||
if path.endswith('/whip'):
|
||||
stream_path = path[1:].rsplit('/whip', 1)[0]
|
||||
answer = handle_whip(stream_path, body)
|
||||
|
||||
if answer:
|
||||
self.send_response(201)
|
||||
self.send_header('Content-Type', 'application/sdp')
|
||||
self.send_header('Location', path)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
self.wfile.write(answer.encode('utf-8'))
|
||||
return
|
||||
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
if path.endswith('/whep'):
|
||||
stream_path = path[1:].rsplit('/whep', 1)[0]
|
||||
answer = handle_whep(stream_path, body)
|
||||
|
||||
if answer:
|
||||
self.send_response(201)
|
||||
self.send_header('Content-Type', 'application/sdp')
|
||||
self.send_header('Location', path)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.end_headers()
|
||||
self.wfile.write(answer.encode('utf-8'))
|
||||
return
|
||||
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def do_OPTIONS(self) -> None:
|
||||
self.send_response(200)
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
|
||||
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
|
||||
self.end_headers()
|
||||
|
||||
server = HTTPServer(('0.0.0.0', WHIP_PORT), WhipWhepHandler)
|
||||
server.timeout = 1
|
||||
|
||||
while running:
|
||||
server.handle_request()
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from facefusion import logger
|
||||
|
||||
RELAY_PORT : int = 8891
|
||||
RELAY_BINARY : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'tools', 'whip_relay')
|
||||
RELAY_PROCESS : Optional[subprocess.Popen[bytes]] = None
|
||||
|
||||
|
||||
def get_whip_url(stream_path : str) -> str:
|
||||
return 'http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/whip'
|
||||
|
||||
|
||||
def get_whep_url(stream_path : str) -> str:
|
||||
return 'http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/whep'
|
||||
|
||||
|
||||
def resolve_binary() -> str:
|
||||
relay_path = shutil.which('whip_relay')
|
||||
|
||||
if relay_path:
|
||||
return relay_path
|
||||
|
||||
if os.path.isfile(RELAY_BINARY):
|
||||
return RELAY_BINARY
|
||||
return RELAY_BINARY
|
||||
|
||||
|
||||
def start() -> None:
|
||||
global RELAY_PROCESS
|
||||
|
||||
subprocess.run([ 'fuser', '-k', str(RELAY_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
|
||||
time.sleep(0.5)
|
||||
|
||||
relay_binary = resolve_binary()
|
||||
|
||||
if not os.path.isfile(relay_binary):
|
||||
logger.warn('whip_relay binary not found at ' + relay_binary + ', skipping', __name__)
|
||||
return
|
||||
|
||||
env = os.environ.copy()
|
||||
env['LD_LIBRARY_PATH'] = '/home/henry/local/lib:' + env.get('LD_LIBRARY_PATH', '')
|
||||
RELAY_PROCESS = subprocess.Popen(
|
||||
[ relay_binary, str(RELAY_PORT) ],
|
||||
env = env,
|
||||
stdout = subprocess.PIPE,
|
||||
stderr = subprocess.PIPE
|
||||
)
|
||||
logger.info('whip relay started on port ' + str(RELAY_PORT), __name__)
|
||||
|
||||
|
||||
def stop() -> None:
|
||||
global RELAY_PROCESS
|
||||
|
||||
if RELAY_PROCESS:
|
||||
RELAY_PROCESS.terminate()
|
||||
RELAY_PROCESS.wait()
|
||||
RELAY_PROCESS = None
|
||||
|
||||
|
||||
def wait_for_ready() -> bool:
|
||||
for _ in range(10):
|
||||
try:
|
||||
response = httpx.get('http://localhost:' + str(RELAY_PORT) + '/health', timeout = 1)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def is_session_ready(stream_path : str) -> bool:
|
||||
try:
|
||||
response = httpx.get('http://localhost:' + str(RELAY_PORT) + '/session/' + stream_path, timeout = 1)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def create_session(stream_path : str) -> int:
|
||||
try:
|
||||
response = httpx.post('http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/create', timeout = 5)
|
||||
|
||||
if response.status_code == 200:
|
||||
return int(response.text)
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
@@ -72,11 +72,11 @@
|
||||
.timeline { display: none; align-items: stretch; padding: 0; background: #0e0e14; border-top: 1px solid #1e1e2e; border-bottom: 1px solid #1e1e2e; flex-shrink: 0; }
|
||||
.timeline.visible { display: flex; }
|
||||
.timeline .transport { display: flex; align-items: center; gap: 2px; padding: 0 0.4rem; background: #12121a; border-right: 1px solid #1e1e2e; }
|
||||
.timeline .transport-btn { width: 28px; height: 28px; border: none; border-radius: 6px; cursor: pointer; display: flex; align-items: center; justify-content: center; background: transparent; color: #888; transition: all 0.15s; }
|
||||
.timeline .transport-btn { width: 36px; height: 36px; border: none; border-radius: 8px; cursor: pointer; display: flex; align-items: center; justify-content: center; background: transparent; color: #888; transition: all 0.15s; }
|
||||
.timeline .transport-btn:hover { background: #1e1e2e; color: #fff; }
|
||||
.timeline .transport-btn:disabled { opacity: 0.25; cursor: not-allowed; }
|
||||
.timeline .transport-btn.active { color: #00b894; }
|
||||
.timeline .transport-btn svg { width: 14px; height: 14px; fill: currentColor; }
|
||||
.timeline .transport-btn.active { color: #888; }
|
||||
.timeline .transport-btn svg { width: 18px; height: 18px; fill: currentColor; }
|
||||
.timeline .time { font-size: 0.75rem; color: #888; font-family: monospace; min-width: 60px; display: flex; align-items: center; justify-content: center; padding: 0 0.6rem; background: #12121a; border-right: 1px solid #1e1e2e; }
|
||||
.timeline .time:last-child { border-right: none; border-left: 1px solid #1e1e2e; }
|
||||
.timeline .track { flex: 1; position: relative; height: 2em; cursor: pointer; background: repeating-linear-gradient(90deg, transparent, transparent 59px, #1a1a25 59px, #1a1a25 60px); }
|
||||
Reference in New Issue
Block a user