From 021b9a15f55d14196cb6c60d95cd887b2df9d7da Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 23 Mar 2026 13:33:50 +0100 Subject: [PATCH] mass test approaches --- facefusion/aiortc_bridge.py | 336 ++++++++++++ facefusion/apis/endpoints/stream.py | 3 + facefusion/rtc.py | 592 ++++++++++++++++++++++ facefusion/webrtc_sfu.py | 546 ++++++++++++++++++++ facefusion/whip_relay.py | 99 ++++ test_whip_stream.html => test_stream.html | 6 +- 6 files changed, 1579 insertions(+), 3 deletions(-) create mode 100644 facefusion/aiortc_bridge.py create mode 100644 facefusion/rtc.py create mode 100644 facefusion/webrtc_sfu.py create mode 100644 facefusion/whip_relay.py rename test_whip_stream.html => test_stream.html (99%) diff --git a/facefusion/aiortc_bridge.py b/facefusion/aiortc_bridge.py new file mode 100644 index 00000000..bc422c86 --- /dev/null +++ b/facefusion/aiortc_bridge.py @@ -0,0 +1,336 @@ +import asyncio +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Optional + +import numpy +from aiortc import RTCPeerConnection, RTCSessionDescription, AudioStreamTrack, VideoStreamTrack +from av import AudioFrame, VideoFrame + +from facefusion import logger +from facefusion.types import VisionFrame + +BRIDGE_PORT_START : int = 8893 +AUDIO_SAMPLE_RATE : int = 48000 + + +class FramePushTrack(VideoStreamTrack): + kind = 'video' + + def __init__(self) -> None: + super().__init__() + self._frame : Optional[VisionFrame] = None + self._lock = threading.Lock() + self._started = False + + def push(self, vision_frame : VisionFrame) -> None: + with self._lock: + self._frame = vision_frame + + async def recv(self) -> VideoFrame: + pts, time_base = await self.next_timestamp() + + with self._lock: + frame_data = self._frame + + if frame_data is None: + frame_data = numpy.zeros((240, 320, 3), dtype = numpy.uint8) + + if not self._started: + self._started = True + logger.info('aiortc track sending first frame', __name__) + + video_frame = 
VideoFrame.from_ndarray(frame_data, format = 'bgr24') + video_frame.pts = pts + video_frame.time_base = time_base + return video_frame + + +class AudioPushTrack(AudioStreamTrack): + kind = 'audio' + + def __init__(self) -> None: + super().__init__() + self._buffer = bytearray() + self._lock = threading.Lock() + self._pts = 0 + self._frame_samples = 960 + + def push(self, pcm_data : bytes) -> None: + with self._lock: + self._buffer.extend(pcm_data) + + if len(self._buffer) > AUDIO_SAMPLE_RATE * 4: + self._buffer = self._buffer[-AUDIO_SAMPLE_RATE * 4:] + + async def recv(self) -> AudioFrame: + await asyncio.sleep(self._frame_samples / AUDIO_SAMPLE_RATE) + needed = self._frame_samples * 2 * 2 + + with self._lock: + if len(self._buffer) >= needed: + chunk = bytes(self._buffer[:needed]) + del self._buffer[:needed] + else: + chunk = None + + if chunk: + pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1) + else: + pcm = numpy.zeros((1, self._frame_samples * 2), dtype = numpy.int16) + + audio_frame = AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo') + audio_frame.sample_rate = AUDIO_SAMPLE_RATE + audio_frame.pts = self._pts + self._pts += self._frame_samples + return audio_frame + + +class AiortcBridge: + def __init__(self) -> None: + global BRIDGE_PORT_START + self.port = BRIDGE_PORT_START + BRIDGE_PORT_START += 1 + self.video_track = FramePushTrack() + self.audio_track = AudioPushTrack() + self.pcs : list = [] + self._http_thread : Optional[threading.Thread] = None + self._running = False + self._has_viewer = False + self._loop = None + + async def start(self) -> None: + self._running = True + self._loop = asyncio.get_event_loop() + self._http_thread = threading.Thread(target = self._run_http, daemon = True) + self._http_thread.start() + logger.info('aiortc bridge started on port ' + str(self.port), __name__) + + async def stop(self) -> None: + self._running = False + + for pc in self.pcs: + try: + loop = asyncio.get_event_loop() + 
asyncio.run_coroutine_threadsafe(pc.close(), loop) + except Exception: + pass + + def push_frame(self, vision_frame : VisionFrame) -> None: + self.video_track.push(vision_frame) + + def push_audio(self, audio_data : bytes) -> None: + self.audio_track.push(audio_data) + + def has_viewer(self) -> bool: + return self._has_viewer + + def _handle_whep(self, sdp_offer : str) -> Optional[str]: + if not self._loop: + return None + + future = asyncio.run_coroutine_threadsafe(self._create_pc(sdp_offer), self._loop) + + try: + return future.result(timeout = 10) + except Exception as exception: + logger.error('whep error: ' + str(exception), __name__) + return None + + async def _create_pc(self, sdp_offer : str) -> Optional[str]: + pc = RTCPeerConnection() + self.pcs.append(pc) + pc.addTrack(self.video_track) + pc.addTrack(self.audio_track) + + offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') + await pc.setRemoteDescription(offer) + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + self._has_viewer = True + return pc.localDescription.sdp + + def _run_http(self) -> None: + bridge = self + + class WhepHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args) -> None: + pass + + def do_OPTIONS(self) -> None: + self.send_response(200) + self.send_header('Access-Control-Allow-Origin', '*') + self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS') + self.send_header('Access-Control-Allow-Headers', 'Content-Type') + self.end_headers() + + def do_POST(self) -> None: + content_length = int(self.headers.get('Content-Length', 0)) + body = self.rfile.read(content_length).decode('utf-8') if content_length else '' + answer = bridge._handle_whep(body) + + if answer: + self.send_response(201) + self.send_header('Content-Type', 'application/sdp') + self.send_header('Access-Control-Allow-Origin', '*') + self.end_headers() + self.wfile.write(answer.encode('utf-8')) + else: + self.send_response(500) + 
self.send_header('Access-Control-Allow-Origin', '*') + self.end_headers() + + server = HTTPServer(('0.0.0.0', self.port), WhepHandler) + server.timeout = 1 + + while self._running: + server.handle_request() + + +class WhipAiortcBridge: + def __init__(self) -> None: + global BRIDGE_PORT_START + self.port = BRIDGE_PORT_START + BRIDGE_PORT_START += 1 + self.whip_port = BRIDGE_PORT_START + BRIDGE_PORT_START += 1 + self._ingest_pc = None + self._relay_track = None + self._viewer_pcs : list = [] + self._http_thread : Optional[threading.Thread] = None + self._running = False + self._loop = None + self._ingest_ready = False + + async def start(self) -> None: + self._running = True + self._loop = asyncio.get_event_loop() + self._http_thread = threading.Thread(target = self._run_http, daemon = True) + self._http_thread.start() + logger.info('whip-aiortc bridge whip=' + str(self.whip_port) + ' whep=' + str(self.port), __name__) + + async def stop(self) -> None: + self._running = False + + if self._ingest_pc: + await self._ingest_pc.close() + + for pc in self._viewer_pcs: + await pc.close() + + def get_whip_url(self) -> str: + return 'http://localhost:' + str(self.whip_port) + '/whip' + + def get_whep_url(self) -> str: + return 'http://localhost:' + str(self.port) + '/whep' + + def is_ready(self) -> bool: + return self._ingest_ready + + def _handle_whip(self, sdp_offer : str) -> Optional[str]: + if not self._loop: + return None + + future = asyncio.run_coroutine_threadsafe(self._create_ingest(sdp_offer), self._loop) + + try: + return future.result(timeout = 10) + except Exception as exception: + logger.error('whip ingest error: ' + str(exception), __name__) + return None + + def _handle_whep(self, sdp_offer : str) -> Optional[str]: + if not self._loop: + return None + + future = asyncio.run_coroutine_threadsafe(self._create_viewer(sdp_offer), self._loop) + + try: + return future.result(timeout = 10) + except Exception as exception: + logger.error('whep error: ' + 
str(exception), __name__) + return None + + async def _create_ingest(self, sdp_offer : str) -> Optional[str]: + from aiortc import MediaStreamTrack + from aiortc.contrib.media import MediaRelay + + pc = RTCPeerConnection() + self._ingest_pc = pc + self._relay = MediaRelay() + + @pc.on('track') + def on_track(track : MediaStreamTrack) -> None: + if track.kind == 'video': + self._relay_track = self._relay.subscribe(track) + self._ingest_ready = True + logger.info('whip ingest video track received', __name__) + + offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') + await pc.setRemoteDescription(offer) + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + return pc.localDescription.sdp + + async def _create_viewer(self, sdp_offer : str) -> Optional[str]: + pc = RTCPeerConnection() + self._viewer_pcs.append(pc) + + if self._relay_track: + pc.addTrack(self._relay_track) + + offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') + await pc.setRemoteDescription(offer) + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + return pc.localDescription.sdp + + def _run_http(self) -> None: + bridge = self + + class Handler(BaseHTTPRequestHandler): + def log_message(self, format, *args) -> None: + pass + + def do_OPTIONS(self) -> None: + self.send_response(200) + self.send_header('Access-Control-Allow-Origin', '*') + self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, DELETE') + self.send_header('Access-Control-Allow-Headers', 'Content-Type, Authorization') + self.end_headers() + + def do_POST(self) -> None: + content_length = int(self.headers.get('Content-Length', 0)) + body = self.rfile.read(content_length).decode('utf-8') if content_length else '' + path = self.path + + if '/whip' in path: + answer = bridge._handle_whip(body) + elif '/whep' in path: + answer = bridge._handle_whep(body) + else: + self.send_response(404) + self.end_headers() + return + + if answer: + self.send_response(201) + 
self.send_header('Content-Type', 'application/sdp') + self.send_header('Location', path) + self.send_header('Access-Control-Allow-Origin', '*') + self.send_header('Access-Control-Expose-Headers', 'Location') + self.end_headers() + self.wfile.write(answer.encode('utf-8')) + else: + self.send_response(500) + self.send_header('Access-Control-Allow-Origin', '*') + self.end_headers() + + whip_server = HTTPServer(('0.0.0.0', self.whip_port), Handler) + whip_server.timeout = 0.5 + whep_server = HTTPServer(('0.0.0.0', self.port), Handler) + whep_server.timeout = 0.5 + + while self._running: + whip_server.handle_request() + whep_server.handle_request() diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index da87e759..24cca8a6 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -805,6 +805,9 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None: with lock: latest_frame_holder[0] = frame + if data[:2] != JPEG_MAGIC: + rtc.send_audio(stream_path, data) + except Exception as exception: logger.error(str(exception), __name__) diff --git a/facefusion/rtc.py b/facefusion/rtc.py new file mode 100644 index 00000000..7379046d --- /dev/null +++ b/facefusion/rtc.py @@ -0,0 +1,592 @@ +import ctypes +import ctypes.util +import os +import threading +import time as _time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Dict, List, Optional, TypeAlias + +import av +import numpy + +from facefusion import logger + +RtcLib : TypeAlias = ctypes.CDLL +WHEP_PORT : int = 8892 + +lib : Optional[RtcLib] = None +sessions : Dict[str, dict] = {} +http_thread : Optional[threading.Thread] = None +running : bool = False + +RTC_NEW = 0 +RTC_CONNECTING = 1 +RTC_CONNECTED = 2 +RTC_DISCONNECTED = 3 +RTC_FAILED = 4 +RTC_CLOSED = 5 + +RTC_GATHERING_NEW = 0 +RTC_GATHERING_INPROGRESS = 1 +RTC_GATHERING_COMPLETE = 2 + +RTC_DIRECTION_SENDONLY = 0 +RTC_DIRECTION_RECVONLY = 1 +RTC_DIRECTION_SENDRECV 
= 2 +RTC_DIRECTION_INACTIVE = 3 +RTC_DIRECTION_UNKNOWN = 4 + +LOG_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) +DESCRIPTION_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p) +CANDIDATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p) +STATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) +GATHERING_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) +TRACK_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) +MESSAGE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) + + +class RtcConfiguration(ctypes.Structure): + _fields_ =\ + [ + ('iceServers', ctypes.POINTER(ctypes.c_char_p)), + ('iceServersCount', ctypes.c_int), + ('proxyServer', ctypes.c_char_p), + ('bindAddress', ctypes.c_char_p), + ('certificateType', ctypes.c_int), + ('iceTransportPolicy', ctypes.c_int), + ('enableIceTcp', ctypes.c_bool), + ('enableIceUdpMux', ctypes.c_bool), + ('disableAutoNegotiation', ctypes.c_bool), + ('forceMediaTransport', ctypes.c_bool), + ('portRangeBegin', ctypes.c_ushort), + ('portRangeEnd', ctypes.c_ushort), + ('mtu', ctypes.c_int), + ('maxMessageSize', ctypes.c_int) + ] + + +class RtcPacketizerInit(ctypes.Structure): + _fields_ =\ + [ + ('ssrc', ctypes.c_uint32), + ('cname', ctypes.c_char_p), + ('payloadType', ctypes.c_uint8), + ('clockRate', ctypes.c_uint32), + ('sequenceNumber', ctypes.c_uint16), + ('timestamp', ctypes.c_uint32), + ('maxFragmentSize', ctypes.c_uint16), + ('nalSeparator', ctypes.c_int), + ('obuPacketization', ctypes.c_int), + ('playoutDelayId', ctypes.c_uint8), + ('playoutDelayMin', ctypes.c_uint16), + ('playoutDelayMax', ctypes.c_uint16), + ('colorSpaceId', ctypes.c_uint8), + ('colorChromaSitingHorz', ctypes.c_uint8), + ('colorChromaSitingVert', ctypes.c_uint8), + ('colorRange', 
ctypes.c_uint8), + ('colorPrimaries', ctypes.c_uint8), + ('colorTransfer', ctypes.c_uint8), + ('colorMatrix', ctypes.c_uint8) + ] + + +def find_library() -> Optional[str]: + lib_path = ctypes.util.find_library('datachannel') + + if lib_path: + return lib_path + + search_paths =\ + [ + '/home/henry/local/lib/libdatachannel.so', + '/usr/local/lib/libdatachannel.so', + '/usr/lib/libdatachannel.so', + '/usr/lib/x86_64-linux-gnu/libdatachannel.so' + ] + + for path in search_paths: + if os.path.isfile(path): + return path + + return None + + +def load_library() -> bool: + global lib + + lib_path = find_library() + + if not lib_path: + logger.warn('libdatachannel.so not found', __name__) + return False + + lib = ctypes.CDLL(lib_path) + setup_prototypes() + lib.rtcInitLogger(4, LOG_CALLBACK_TYPE(0)) + logger.info('libdatachannel loaded from ' + lib_path, __name__) + return True + + +def setup_prototypes() -> None: + lib.rtcInitLogger.argtypes = [ctypes.c_int, LOG_CALLBACK_TYPE] + lib.rtcInitLogger.restype = None + + lib.rtcCreatePeerConnection.argtypes = [ctypes.POINTER(RtcConfiguration)] + lib.rtcCreatePeerConnection.restype = ctypes.c_int + + lib.rtcDeletePeerConnection.argtypes = [ctypes.c_int] + lib.rtcDeletePeerConnection.restype = ctypes.c_int + + lib.rtcSetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p] + lib.rtcSetLocalDescription.restype = ctypes.c_int + + lib.rtcSetRemoteDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p] + lib.rtcSetRemoteDescription.restype = ctypes.c_int + + lib.rtcGetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int] + lib.rtcGetLocalDescription.restype = ctypes.c_int + + lib.rtcAddTrack.argtypes = [ctypes.c_int, ctypes.c_char_p] + lib.rtcAddTrack.restype = ctypes.c_int + + lib.rtcSetUserPointer.argtypes = [ctypes.c_int, ctypes.c_void_p] + lib.rtcSetUserPointer.restype = None + + lib.rtcSetLocalDescriptionCallback.argtypes = [ctypes.c_int, DESCRIPTION_CALLBACK_TYPE] + 
lib.rtcSetLocalDescriptionCallback.restype = ctypes.c_int + + lib.rtcSetLocalCandidateCallback.argtypes = [ctypes.c_int, CANDIDATE_CALLBACK_TYPE] + lib.rtcSetLocalCandidateCallback.restype = ctypes.c_int + + lib.rtcSetStateChangeCallback.argtypes = [ctypes.c_int, STATE_CALLBACK_TYPE] + lib.rtcSetStateChangeCallback.restype = ctypes.c_int + + lib.rtcSetGatheringStateChangeCallback.argtypes = [ctypes.c_int, GATHERING_CALLBACK_TYPE] + lib.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int + + lib.rtcSetTrackCallback.argtypes = [ctypes.c_int, TRACK_CALLBACK_TYPE] + lib.rtcSetTrackCallback.restype = ctypes.c_int + + lib.rtcSetMessageCallback.argtypes = [ctypes.c_int, MESSAGE_CALLBACK_TYPE] + lib.rtcSetMessageCallback.restype = ctypes.c_int + + lib.rtcSendMessage.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_int] + lib.rtcSendMessage.restype = ctypes.c_int + + lib.rtcSetH264Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)] + lib.rtcSetH264Packetizer.restype = ctypes.c_int + + lib.rtcSetVP8Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)] + lib.rtcSetVP8Packetizer.restype = ctypes.c_int + + lib.rtcChainRtcpSrReporter.argtypes = [ctypes.c_int] + lib.rtcChainRtcpSrReporter.restype = ctypes.c_int + + lib.rtcChainRtcpNackResponder.argtypes = [ctypes.c_int, ctypes.c_uint] + lib.rtcChainRtcpNackResponder.restype = ctypes.c_int + + lib.rtcSetTrackRtpTimestamp.argtypes = [ctypes.c_int, ctypes.c_uint32] + lib.rtcSetTrackRtpTimestamp.restype = ctypes.c_int + + lib.rtcSetOpenCallback.argtypes = [ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)] + lib.rtcSetOpenCallback.restype = ctypes.c_int + + lib.rtcIsOpen.argtypes = [ctypes.c_int] + lib.rtcIsOpen.restype = ctypes.c_bool + + lib.rtcSetOpusPacketizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)] + lib.rtcSetOpusPacketizer.restype = ctypes.c_int + + +callback_refs : List = [] + + +def create_peer_connection() -> int: + config = 
RtcConfiguration() + config.iceServers = None + config.iceServersCount = 0 + config.proxyServer = None + config.bindAddress = None + config.certificateType = 0 + config.iceTransportPolicy = 0 + config.enableIceTcp = False + config.enableIceUdpMux = True + config.disableAutoNegotiation = False + config.forceMediaTransport = True + config.portRangeBegin = 0 + config.portRangeEnd = 0 + config.mtu = 0 + config.maxMessageSize = 0 + return lib.rtcCreatePeerConnection(ctypes.byref(config)) + + +next_rtp_port : int = 16000 + + +def create_session(stream_path : str) -> None: + sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': 0, 'rtp_fd': None} + + +def create_rtp_session(stream_path : str) -> int: + global next_rtp_port + import socket as sock + + rtp_port = next_rtp_port + next_rtp_port += 1 + + rtp_fd = sock.socket(sock.AF_INET, sock.SOCK_DGRAM) + rtp_fd.bind(('127.0.0.1', rtp_port)) + rtp_fd.settimeout(1.0) + + sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': rtp_port, 'rtp_fd': rtp_fd} + + rtp_thread = threading.Thread(target = run_rtp_forwarder, args = (stream_path,), daemon = True) + rtp_thread.start() + + return rtp_port + + +def run_rtp_forwarder(stream_path : str) -> None: + session = sessions.get(stream_path) + + if not session: + return + + rtp_fd = session.get('rtp_fd') + + while running and session.get('rtp_fd'): + try: + data, addr = rtp_fd.recvfrom(262144) + + if not data: + continue + + send_to_viewers(stream_path, data) + except Exception: + continue + + +send_start_time : float = 0 +opus_encoder : Optional[av.CodecContext] = None +audio_buffer : bytearray = bytearray() +audio_lock : threading.Lock = threading.Lock() +OPUS_FRAME_SAMPLES : int = 960 + + +def send_to_viewers(stream_path : str, data : bytes) -> None: + global send_start_time + + session = sessions.get(stream_path) + + if not session: + return + + viewers = session.get('viewers') + + if not viewers: + return + + if send_start_time == 0: + send_start_time = 
_time.monotonic() + + elapsed = _time.monotonic() - send_start_time + timestamp = int(elapsed * 90000) & 0xFFFFFFFF + buf = ctypes.create_string_buffer(data) + data_len = len(data) + + for viewer in viewers: + if not viewer.get('connected'): + continue + + for track_id in viewer.get('tracks', []): + if not lib.rtcIsOpen(track_id): + continue + + lib.rtcSetTrackRtpTimestamp(track_id, timestamp) + lib.rtcSendMessage(track_id, buf, data_len) + + +def get_opus_encoder() -> av.CodecContext: + global opus_encoder + + if not opus_encoder: + opus_encoder = av.CodecContext.create('libopus', 'w') + opus_encoder.sample_rate = 48000 + opus_encoder.layout = 'stereo' + opus_encoder.format = av.AudioFormat('s16') + opus_encoder.open() + + return opus_encoder + + +def send_audio(stream_path : str, pcm_data : bytes) -> None: + session = sessions.get(stream_path) + + if not session: + return + + viewers = session.get('viewers') + + if not viewers: + return + + with audio_lock: + audio_buffer.extend(pcm_data) + needed = OPUS_FRAME_SAMPLES * 2 * 2 + + while len(audio_buffer) >= needed: + chunk = bytes(audio_buffer[:needed]) + del audio_buffer[:needed] + + encoder = get_opus_encoder() + pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1) + frame = av.AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo') + frame.sample_rate = 48000 + frame.pts = None + + for packet in encoder.encode(frame): + opus_data = bytes(packet) + + for viewer in viewers: + if not viewer.get('connected'): + continue + + audio_track_id = viewer.get('audio_track') + + if not audio_track_id: + continue + + if not lib.rtcIsOpen(audio_track_id): + continue + + elapsed = _time.monotonic() - send_start_time if send_start_time > 0 else 0 + timestamp = int(elapsed * 48000) & 0xFFFFFFFF + buf = ctypes.create_string_buffer(opus_data) + lib.rtcSetTrackRtpTimestamp(audio_track_id, timestamp) + lib.rtcSendMessage(audio_track_id, buf, len(opus_data)) + + +h264_au_buffer : Dict[str, bytes] = {} + + +def 
send_vp8_frame(stream_path : str, frame_data : bytes) -> None: + send_h264_frame(stream_path, frame_data) + + +def send_h264_frame(stream_path : str, frame_data : bytes) -> None: + session = sessions.get(stream_path) + + if not session: + return + + prev = h264_au_buffer.get(stream_path, b'') + buf = prev + frame_data + + au_starts = [] + i = 0 + + while i < len(buf) - 4: + if buf[i] == 0 and buf[i + 1] == 0 and buf[i + 2] == 0 and buf[i + 3] == 1 and i + 4 < len(buf): + nal_type = buf[i + 4] & 0x1f + + if nal_type == 7 or nal_type == 5: + au_starts.append(i) + + i += 1 + + if len(au_starts) < 2: + h264_au_buffer[stream_path] = buf + return + + for j in range(len(au_starts) - 1): + au = buf[au_starts[j]:au_starts[j + 1]] + + for viewer in session.get('viewers', []): + tracks = viewer.get('tracks', []) + + if tracks: + lib.rtcSendMessage(tracks[0], au, len(au)) + + h264_au_buffer[stream_path] = buf[au_starts[-1]:] + + +def destroy_session(stream_path : str) -> None: + session = sessions.pop(stream_path, None) + + if not session: + return + + for viewer in session.get('viewers', []): + pc_id = viewer.get('pc') + + if pc_id is not None: + lib.rtcDeletePeerConnection(pc_id) + + +def send_data(stream_path : str, data : bytes) -> None: + session = sessions.get(stream_path) + + if not session: + return + + for viewer in session.get('viewers', []): + for track_id in viewer.get('tracks', []): + lib.rtcSendMessage(track_id, data, len(data)) + + +def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]: + session = sessions.get(stream_path) + + if not session: + return None + + if not lib: + return None + + pc = create_peer_connection() + gathering_done = threading.Event() + local_sdp_holder = [None] + + def on_description(pc_id, sdp, type_str, user_ptr): + local_sdp_holder[0] = sdp.decode('utf-8') if sdp else None + + def on_candidate(pc_id, candidate, mid, user_ptr): + pass + + def on_gathering(pc_id, state, user_ptr): + if state == RTC_GATHERING_COMPLETE: 
+ gathering_done.set() + + viewer = {'pc': pc, 'tracks': [], 'connected': False} + + def on_state(pc_id, state, user_ptr): + if state == RTC_CONNECTED: + viewer['connected'] = True + logger.info('viewer pc connected', __name__) + + desc_cb = DESCRIPTION_CALLBACK_TYPE(on_description) + cand_cb = CANDIDATE_CALLBACK_TYPE(on_candidate) + gather_cb = GATHERING_CALLBACK_TYPE(on_gathering) + state_cb = STATE_CALLBACK_TYPE(on_state) + callback_refs.extend([desc_cb, cand_cb, gather_cb, state_cb]) + + lib.rtcSetLocalDescriptionCallback(pc, desc_cb) + lib.rtcSetLocalCandidateCallback(pc, cand_cb) + lib.rtcSetGatheringStateChangeCallback(pc, gather_cb) + lib.rtcSetStateChangeCallback(pc, state_cb) + + video_sdp = b'm=video 9 UDP/TLS/RTP/SAVPF 96\r\na=rtpmap:96 VP8/90000\r\na=sendonly\r\na=mid:0\r\na=rtcp-mux\r\n' + audio_sdp = b'm=audio 9 UDP/TLS/RTP/SAVPF 111\r\na=rtpmap:111 opus/48000/2\r\na=sendonly\r\na=mid:1\r\na=rtcp-mux\r\n' + + video_track = lib.rtcAddTrack(pc, video_sdp) + audio_track = lib.rtcAddTrack(pc, audio_sdp) + + video_packetizer = RtcPacketizerInit() + video_packetizer.ssrc = 42 + video_packetizer.cname = b'video' + video_packetizer.payloadType = 96 + video_packetizer.clockRate = 90000 + video_packetizer.maxFragmentSize = 1200 + lib.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) + lib.rtcChainRtcpSrReporter(video_track) + lib.rtcChainRtcpNackResponder(video_track, 512) + + audio_packetizer = RtcPacketizerInit() + audio_packetizer.ssrc = 43 + audio_packetizer.cname = b'audio' + audio_packetizer.payloadType = 111 + audio_packetizer.clockRate = 48000 + lib.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer)) + lib.rtcChainRtcpSrReporter(audio_track) + + viewer['tracks'] = [video_track] + viewer['audio_track'] = audio_track + session['viewers'].append(viewer) + + lib.rtcSetRemoteDescription(pc, sdp_offer.encode('utf-8'), b'offer') + + gathering_done.wait(timeout = 3) + + buf_size = 16384 + buf = ctypes.create_string_buffer(buf_size) 
+ result = lib.rtcGetLocalDescription(pc, buf, buf_size) + + if result > 0: + local_sdp = buf.value.decode('utf-8') + elif local_sdp_holder[0]: + local_sdp = local_sdp_holder[0] + else: + session['viewers'].remove(viewer) + return None + + return local_sdp + + +def start() -> None: + global running, http_thread + + if not load_library(): + return + + running = True + http_thread = threading.Thread(target = run_http_server, daemon = True) + http_thread.start() + logger.info('rtc whep server started on port ' + str(WHEP_PORT), __name__) + + +def stop() -> None: + global running + + running = False + + for stream_path in list(sessions.keys()): + destroy_session(stream_path) + + +def run_http_server() -> None: + class WhepHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args) -> None: + pass + + def send_cors_headers(self) -> None: + self.send_header('Access-Control-Allow-Origin', '*') + self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS') + self.send_header('Access-Control-Allow-Headers', 'Content-Type') + + def do_OPTIONS(self) -> None: + self.send_response(200) + self.send_cors_headers() + self.end_headers() + + def do_POST(self) -> None: + path = self.path + + if not path.endswith('/whep'): + self.send_response(404) + self.send_cors_headers() + self.end_headers() + return + + stream_path = path[1:].rsplit('/whep', 1)[0] + content_length = int(self.headers.get('Content-Length', 0)) + body = self.rfile.read(content_length).decode('utf-8') if content_length else '' + answer = handle_whep_offer(stream_path, body) + + if answer: + self.send_response(201) + self.send_header('Content-Type', 'application/sdp') + self.send_header('Location', path) + self.send_cors_headers() + self.end_headers() + self.wfile.write(answer.encode('utf-8')) + return + + self.send_response(404) + self.send_cors_headers() + self.end_headers() + + server = HTTPServer(('0.0.0.0', WHEP_PORT), WhepHandler) + server.timeout = 1 + + while running: + 
server.handle_request() diff --git a/facefusion/webrtc_sfu.py b/facefusion/webrtc_sfu.py new file mode 100644 index 00000000..55e2f057 --- /dev/null +++ b/facefusion/webrtc_sfu.py @@ -0,0 +1,546 @@ +import binascii +import hashlib +import os +import socket +import struct +import threading +import time +from typing import Dict, Optional, Tuple, TypeAlias + +import pylibsrtp +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from OpenSSL import SSL + +from facefusion import logger + +SrtpSession : TypeAlias = pylibsrtp.Session +SrtpPolicy : TypeAlias = pylibsrtp.Policy + +WHIP_PORT : int = 8890 +ICE_UFRAG_LENGTH : int = 4 +ICE_PWD_LENGTH : int = 22 +RTP_HEADER_SIZE : int = 12 + +SRTP_PROFILES =\ +[ + { + 'name': b'SRTP_AES128_CM_SHA1_80', + 'libsrtp': SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80, + 'key_len': 16, + 'salt_len': 14 + } +] + +sessions : Dict[str, dict] = {} +server_cert = None +server_key = None +server_fingerprint : str = '' +udp_socket : Optional[socket.socket] = None +http_thread : Optional[threading.Thread] = None +udp_thread : Optional[threading.Thread] = None +running : bool = False + + +def generate_credentials() -> None: + global server_cert, server_key, server_fingerprint + + server_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, binascii.hexlify(os.urandom(16)).decode())]) + import datetime + now = datetime.datetime.now(tz = datetime.timezone.utc) + builder = x509.CertificateBuilder().subject_name(name).issuer_name(name).public_key(server_key.public_key()).serial_number(x509.random_serial_number()).not_valid_before(now - datetime.timedelta(days = 1)).not_valid_after(now + datetime.timedelta(days = 30)) + server_cert = builder.sign(server_key, hashes.SHA256(), default_backend()) + fp = 
server_cert.fingerprint(hashes.SHA256()).hex().upper() + server_fingerprint = ':'.join(fp[i:i + 2] for i in range(0, len(fp), 2)) + + +def generate_ice_credentials() -> Tuple[str, str]: + ufrag = binascii.hexlify(os.urandom(ICE_UFRAG_LENGTH)).decode() + pwd = binascii.hexlify(os.urandom(ICE_PWD_LENGTH)).decode()[:ICE_PWD_LENGTH] + return ufrag, pwd + + +def parse_sdp_offer(sdp : str) -> dict: + result = {'ice_ufrag': '', 'ice_pwd': '', 'fingerprint': '', 'setup': '', 'media': [], 'candidates': []} + current_media = None + + for line in sdp.splitlines(): + line = line.strip() + + if line.startswith('a=ice-ufrag:'): + result['ice_ufrag'] = line.split(':', 1)[1] + if line.startswith('a=ice-pwd:'): + result['ice_pwd'] = line.split(':', 1)[1] + if line.startswith('a=fingerprint:'): + result['fingerprint'] = line.split(' ', 1)[1] if ' ' in line else '' + if line.startswith('a=setup:'): + result['setup'] = line.split(':', 1)[1] + if line.startswith('a=candidate:'): + result['candidates'].append(line[12:]) + if line.startswith('m='): + parts = line[2:].split() + current_media = {'kind': parts[0], 'port': int(parts[1]), 'profile': parts[2], 'formats': parts[3:], 'codec_lines': [], 'mid': None} + result['media'].append(current_media) + if current_media: + if line.startswith('a=rtpmap:') or line.startswith('a=fmtp:') or line.startswith('a=rtcp-fb:'): + current_media['codec_lines'].append(line) + if line.startswith('a=mid:'): + current_media['mid'] = line.split(':', 1)[1] + if line.startswith('a=extmap:'): + current_media['codec_lines'].append(line) + + return result + + +def build_sdp_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int) -> str: + lines = [] + lines.append('v=0') + lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1') + lines.append('s=-') + lines.append('t=0 0') + + mids = [] + for i, media in enumerate(offer.get('media', [])): + mids.append(str(i)) + + if mids: + lines.append('a=group:BUNDLE ' + ' '.join(mids)) + + 
lines.append('a=ice-lite') + + for i, media in enumerate(offer.get('media', [])): + kind = media.get('kind') + formats = media.get('formats', []) + profile = media.get('profile', 'UDP/TLS/RTP/SAVPF') + mid = media.get('mid', str(i)) + lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats)) + lines.append('c=IN IP4 127.0.0.1') + lines.append('a=rtcp:9 IN IP4 0.0.0.0') + lines.append('a=ice-ufrag:' + local_ufrag) + lines.append('a=ice-pwd:' + local_pwd) + lines.append('a=ice-options:ice2') + lines.append('a=fingerprint:sha-256 ' + server_fingerprint) + lines.append('a=setup:passive') + lines.append('a=mid:' + mid) + lines.append('a=rtcp-mux') + lines.append('a=recvonly') + + for codec_line in media.get('codec_lines', []): + lines.append(codec_line) + + lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host') + lines.append('a=end-of-candidates') + + return '\r\n'.join(lines) + '\r\n' + + +def build_whep_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int, ingest_offer : dict) -> str: + lines = [] + lines.append('v=0') + lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1') + lines.append('s=-') + lines.append('t=0 0') + + mids = [] + for i, media in enumerate(offer.get('media', [])): + mid = media.get('mid', str(i)) + mids.append(mid) + + if mids: + lines.append('a=group:BUNDLE ' + ' '.join(mids)) + + lines.append('a=ice-lite') + + for i, media in enumerate(offer.get('media', [])): + kind = media.get('kind') + formats = media.get('formats', []) + profile = media.get('profile', 'UDP/TLS/RTP/SAVPF') + mid = media.get('mid', str(i)) + lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats)) + lines.append('c=IN IP4 127.0.0.1') + lines.append('a=rtcp:9 IN IP4 0.0.0.0') + lines.append('a=ice-ufrag:' + local_ufrag) + lines.append('a=ice-pwd:' + local_pwd) + lines.append('a=ice-options:ice2') + lines.append('a=fingerprint:sha-256 ' + server_fingerprint) + 
def create_ssl_context() -> 'SSL.Context':
    """
    Create the DTLS context that is used for every incoming peer.

    The context presents the module level certificate and key and offers the
    SRTP_AES128_CM_SHA1_80 protection profile.
    """
    context = SSL.Context(SSL.DTLS_METHOD)
    # NOTE(review): the callback accepts every peer certificate and nothing visible here compares it with the a=fingerprint of the offer - confirm this is intended
    context.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True)
    context.use_certificate(server_cert)
    context.use_privatekey(server_key)
    context.set_cipher_list(b'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-SHA')
    context.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80')
    return context


def create_session(stream_path : str) -> None:
    """
    Register a fresh session for the stream path, replacing an existing one.
    """
    ufrag, pwd = generate_ice_credentials()
    sessions[stream_path] =\
    {
        'ice_ufrag': ufrag,
        'ice_pwd': pwd,
        'ingest': None,
        'viewers': [],
        'ingest_offer': None,
        'rx_srtp': None,
        'tx_sessions': []
    }


def destroy_session(stream_path : str) -> None:
    """
    Forget the session of the stream path, unknown paths are ignored.
    """
    sessions.pop(stream_path, None)


def _get_local_port() -> int:
    """
    Return the port the UDP socket is bound to, or WHIP_PORT while there is no socket.
    """
    if udp_socket:
        return udp_socket.getsockname()[1]
    return WHIP_PORT


def handle_whip(stream_path : str, sdp_offer : str) -> Optional[str]:
    """
    Answer the SDP offer of a publisher.

    The parsed offer is stored on the session as 'ingest_offer'.

    :return: SDP answer, or None when no session exists for the stream path
    """
    session = sessions.get(stream_path)

    if not session:
        return None

    offer = parse_sdp_offer(sdp_offer)
    session['ingest_offer'] = offer
    return build_sdp_answer(offer, session.get('ice_ufrag'), session.get('ice_pwd'), _get_local_port())


def handle_whep(stream_path : str, sdp_offer : str) -> Optional[str]:
    """
    Answer the SDP offer of a viewer.

    Every viewer receives its own ICE credentials, which are appended to the
    'viewers' list of the session together with the parsed offer.

    :return: SDP answer, or None when no session exists for the stream path
    """
    session = sessions.get(stream_path)

    if not session:
        return None

    offer = parse_sdp_offer(sdp_offer)
    viewer_ufrag, viewer_pwd = generate_ice_credentials()
    # the key always exists and holds None until a publisher connected, so a dict default would never apply
    ingest_offer = session.get('ingest_offer') or offer
    answer = build_whep_answer(offer, viewer_ufrag, viewer_pwd, _get_local_port(), ingest_offer)
    session['viewers'].append({'offer': offer, 'ice_ufrag': viewer_ufrag, 'ice_pwd': viewer_pwd})
    return answer
def start() -> None:
    """
    Start the SFU: create credentials, bind the UDP socket and spawn the UDP and HTTP threads.

    The socket is bound before the running flag is raised, so a failing bind
    raises OSError without leaving a half started state behind.
    """
    global running, udp_socket, http_thread

    generate_credentials()
    bound_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    try:
        bound_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        bound_socket.bind(('0.0.0.0', WHIP_PORT))
        # the timeout lets the receive loop re-check the running flag once per second
        bound_socket.settimeout(1.0)
    except OSError:
        bound_socket.close()
        raise

    udp_socket = bound_socket
    running = True

    udp_thread_instance = threading.Thread(target = run_udp_loop, daemon = True)
    udp_thread_instance.start()

    http_thread = threading.Thread(target = run_http_server, daemon = True)
    http_thread.start()
    logger.info('webrtc sfu started on port ' + str(WHIP_PORT), __name__)


def stop() -> None:
    """
    Stop the SFU: lower the running flag, close the UDP socket and drop all DTLS state.
    """
    global running, udp_socket

    running = False

    if udp_socket:
        udp_socket.close()
        udp_socket = None

    # associations are tied to the closed socket and must not leak into the next start
    dtls_connections.clear()


# per peer DTLS state keyed by the remote address, holding 'ssl', 'encrypted', 'rx_srtp' and 'tx_srtp'
dtls_connections : Dict[tuple, dict] = {}


def run_udp_loop() -> None:
    """
    Receive datagrams while the running flag is set and dispatch them by their first byte.

    First byte 0 or 1 goes to handle_stun, 20 to 63 to handle_dtls and 128 to
    191 to handle_srtp, everything else is dropped.
    """
    while running:
        try:
            data, addr = udp_socket.recvfrom(2048)

            if not data:
                continue

            first_byte = data[0]

            if first_byte in (0, 1):
                handle_stun(data, addr)
            elif 20 <= first_byte <= 63:
                handle_dtls(data, addr)
            elif 128 <= first_byte <= 191:
                handle_srtp(data, addr)

        except socket.timeout:
            continue
        except Exception:
            # a single bad datagram or a socket closed by stop() must not end the loop early
            continue


def handle_dtls(data : bytes, addr : tuple) -> None:
    """
    Feed a DTLS datagram into the association of the peer and send the reply.

    Unknown peers get a new server side association. Once the handshake is
    done the SRTP sessions are created and only then the association is marked
    as encrypted. Peers that closed the association or failed the handshake
    are removed from dtls_connections.
    """
    conn = dtls_connections.get(addr)

    if not conn:
        ssl_conn = SSL.Connection(create_ssl_context())
        ssl_conn.set_accept_state()
        conn = {'ssl': ssl_conn, 'encrypted': False, 'rx_srtp': None, 'tx_srtp': None}
        dtls_connections[addr] = conn

    ssl_conn = conn.get('ssl')
    ssl_conn.bio_write(data)
    is_alive = True

    try:
        if conn.get('encrypted'):
            ssl_conn.recv(1500)
        else:
            ssl_conn.do_handshake()
            setup_srtp_session(conn)
            conn['encrypted'] = True
            logger.info('dtls handshake complete with ' + str(addr), __name__)
    except SSL.WantReadError:
        # more datagrams are needed, nothing to do yet
        pass
    except SSL.ZeroReturnError:
        # the peer closed the association
        is_alive = False
    except SSL.Error:
        # errors on an established association are ignored, a failed handshake starts over
        if not conn.get('encrypted'):
            logger.warn('dtls handshake failed with ' + str(addr), __name__)
            is_alive = False
    except Exception as exception:
        logger.warn('dtls setup failed with ' + str(addr) + ': ' + str(exception), __name__)
        is_alive = False

    # pending output is sent even for dead peers
    flush_dtls(ssl_conn, addr)

    if not is_alive:
        dtls_connections.pop(addr, None)
def flush_dtls(ssl_conn : 'SSL.Connection', addr : tuple) -> None:
    """
    Send pending DTLS output of the association to the peer, at most 1500 bytes per call.
    """
    try:
        outdata = ssl_conn.bio_read(1500)
    except SSL.Error:
        # raised when there is nothing to send
        return

    # the socket is reset to None by stop()
    if outdata and udp_socket:
        udp_socket.sendto(outdata, addr)


def _create_srtp_session(key : bytes, ssrc_type : int) -> 'SrtpSession':
    """
    Create an AES128_CM_SHA1_80 SRTP session for the given master key and salt.
    """
    policy = SrtpPolicy(key = key, ssrc_type = ssrc_type, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
    policy.allow_repeat_tx = True
    policy.window_size = 1024
    return SrtpSession(policy)


def setup_srtp_session(conn : dict) -> None:
    """
    Derive the SRTP keys from a finished DTLS handshake and store both sessions on the connection.

    Inbound packets are decrypted with the client key ('rx_srtp'), outbound
    packets are encrypted with the server key ('tx_srtp').

    :raises RuntimeError: when the handshake did not negotiate an SRTP profile
    """
    ssl_conn = conn.get('ssl')

    if not ssl_conn.get_selected_srtp_profile():
        raise RuntimeError('dtls peer did not negotiate an srtp profile')

    key_len = 16
    salt_len = 14
    # exported layout: client key, server key, client salt, server salt
    material = ssl_conn.export_keying_material(b'EXTRACTOR-dtls_srtp', 2 * (key_len + salt_len))
    client_key = material[:key_len] + material[2 * key_len:2 * key_len + salt_len]
    peer_key = material[key_len:2 * key_len] + material[2 * key_len + salt_len:]

    rx_session = _create_srtp_session(client_key, SrtpPolicy.SSRC_ANY_INBOUND)
    tx_session = _create_srtp_session(peer_key, SrtpPolicy.SSRC_ANY_OUTBOUND)
    conn['rx_srtp'] = rx_session
    conn['tx_srtp'] = tx_session


def is_rtcp(data : bytes) -> bool:
    """
    Tell whether a packet is RTCP, based on the second byte without its top bit being in 64 to 95.
    """
    if len(data) < 2:
        return False
    return 64 <= (data[1] & 0x7F) <= 95


def handle_srtp(data : bytes, addr : tuple) -> None:
    """
    Decrypt a packet of a known peer and forward it to every other peer.

    Packets of unknown peers, of peers without finished handshake and packets
    that fail decryption are dropped.
    """
    conn = dtls_connections.get(addr)

    if not conn or not conn.get('rx_srtp'):
        return

    rx_session = conn.get('rx_srtp')

    try:
        plain = rx_session.unprotect_rtcp(data) if is_rtcp(data) else rx_session.unprotect(data)
        forward_rtp(plain, addr)
    except Exception:
        # replayed or corrupted packets are expected on the wire and deliberately dropped
        pass


def forward_rtp(data : bytes, source_addr : tuple) -> None:
    """
    Encrypt a decrypted RTP or RTCP packet for every peer except its sender and send it.

    A failure for one peer does not stop delivery to the remaining peers.
    """
    has_rtcp = is_rtcp(data)

    # a snapshot keeps the loop valid when the connection table changes meanwhile
    for other_addr, conn in list(dtls_connections.items()):
        tx_session = conn.get('tx_srtp')

        if other_addr == source_addr or not tx_session:
            continue

        try:
            encrypted = tx_session.protect_rtcp(data) if has_rtcp else tx_session.protect(data)
            udp_socket.sendto(encrypted, other_addr)
        except Exception:
            pass
def _find_ice_password(local_ufrag : str) -> Optional[str]:
    """
    Look up the ICE password that belongs to a local username fragment.

    Both the credentials of a session itself and those of its viewers are searched.
    """
    # snapshots, as sessions and viewers are modified by the HTTP thread
    for session in list(sessions.values()):
        if session.get('ice_ufrag') == local_ufrag:
            return session.get('ice_pwd')

        for viewer in list(session.get('viewers', [])):
            if viewer.get('ice_ufrag') == local_ufrag:
                return viewer.get('ice_pwd')

    return None


def handle_stun(data : bytes, addr : tuple) -> None:
    """
    Answer a STUN binding request whose USERNAME starts with a known local ufrag.

    Everything else is dropped silently: short packets, other message types, a
    wrong magic cookie, a missing USERNAME or an unknown ufrag.
    """
    if len(data) < 20:
        return

    msg_type, msg_length, magic_cookie = struct.unpack('!HHI', data[:8])

    if msg_type != 0x0001 or magic_cookie != 0x2112A442:
        return

    # NOTE(review): the MESSAGE-INTEGRITY of the request is not verified, knowing a ufrag is enough to get a response
    transaction_id = data[8:20]
    username = None
    offset = 20
    # never read beyond the datagram, whatever the length field claims
    end = min(20 + msg_length, len(data))

    while offset + 4 <= end:
        attr_type, attr_length = struct.unpack('!HH', data[offset:offset + 4])
        value_end = offset + 4 + attr_length

        if value_end > end:
            break

        if attr_type == 0x0006:
            username = data[offset + 4:value_end].decode('utf-8', errors = 'ignore')

        # attribute values are padded to a multiple of four bytes
        offset = value_end + (-attr_length % 4)

    if not username:
        return

    # the part in front of the colon is the fragment handed out by this server
    local_ufrag = username.split(':', 1)[0]
    session_pwd = _find_ice_password(local_ufrag)

    if not session_pwd:
        return

    response = build_stun_response(transaction_id, addr, session_pwd)

    # the socket is reset to None by stop()
    if udp_socket:
        udp_socket.sendto(response, addr)


def build_stun_response(transaction_id : bytes, addr : tuple, password : str) -> bytes:
    """
    Build a STUN binding success response for an IPv4 peer.

    The response carries XOR-MAPPED-ADDRESS, MESSAGE-INTEGRITY (HMAC-SHA1 keyed
    with the password) and FINGERPRINT, in this order.

    :param transaction_id: the 12 byte transaction id of the request
    :param addr: (host, port) of the peer, host being a dotted IPv4 string
    :param password: ICE password used as HMAC key
    :return: the encoded message
    """
    import hmac
    import zlib

    magic_cookie = 0x2112A442
    cookie_and_id = struct.pack('!I', magic_cookie) + transaction_id

    def build_header(body_length : int) -> bytes:
        return struct.pack('!HH', 0x0101, body_length) + cookie_and_id

    xport = addr[1] ^ (magic_cookie >> 16)
    xip = struct.unpack('!I', socket.inet_aton(addr[0]))[0] ^ magic_cookie
    body = struct.pack('!HHBBHI', 0x0020, 8, 0, 0x01, xport, xip)

    # the HMAC covers a header whose length already counts the 24 byte integrity attribute
    integrity = hmac.new(password.encode('utf-8'), build_header(len(body) + 24) + body, hashlib.sha1).digest()
    body += struct.pack('!HH', 0x0008, 20) + integrity

    # the CRC covers a header whose length already counts the 8 byte fingerprint attribute
    crc = zlib.crc32(build_header(len(body) + 8) + body) ^ 0x5354554E
    body += struct.pack('!HHI', 0x8028, 4, crc & 0xFFFFFFFF)
    return build_header(len(body)) + body
def run_http_server() -> None:
    """
    Serve the WHIP and WHEP signalling endpoints until the running flag is lowered.

    POST '<stream path>/whip' and POST '<stream path>/whep' take an SDP offer
    as body and reply with 201 and the SDP answer, or with 404 when the handler
    returns no answer. A request whose body cannot be read gets 400. The
    listening socket is closed when the loop ends.
    """
    from http.server import HTTPServer, BaseHTTPRequestHandler

    routes = {'/whip': handle_whip, '/whep': handle_whep}

    class WhipWhepHandler(BaseHTTPRequestHandler):
        def log_message(self, format, *args) -> None:
            # keep the request log of the base class quiet
            pass

        def read_body(self) -> Optional[str]:
            """
            Return the decoded request body, or None when it cannot be read.
            """
            try:
                content_length = int(self.headers.get('Content-Length', 0))
            except ValueError:
                return None

            # a negative length would make read() wait for the end of the stream
            if content_length < 0:
                return None

            if not content_length:
                return ''

            try:
                return self.rfile.read(content_length).decode('utf-8')
            except UnicodeDecodeError:
                return None

        def send_status(self, status : int) -> None:
            self.send_response(status)
            self.end_headers()

        def send_answer(self, path : str, answer : str) -> None:
            self.send_response(201)
            self.send_header('Content-Type', 'application/sdp')
            self.send_header('Location', path)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.end_headers()
            self.wfile.write(answer.encode('utf-8'))

        def do_POST(self) -> None:
            path = self.path
            body = self.read_body()

            if body is None:
                self.send_status(400)
                return

            for suffix, handler in routes.items():
                if path.endswith(suffix):
                    stream_path = path[1:].rsplit(suffix, 1)[0]
                    answer = handler(stream_path, body)

                    if answer:
                        self.send_answer(path, answer)
                        return
                    break

            self.send_status(404)

        def do_OPTIONS(self) -> None:
            self.send_response(200)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
            self.end_headers()

    server = HTTPServer(('0.0.0.0', WHIP_PORT), WhipWhepHandler)
    # handle_request() returns after one second without a request, so the flag is re-checked
    server.timeout = 1

    try:
        while running:
            server.handle_request()
    finally:
        server.server_close()
RELAY_PORT : int = 8891
RELAY_BINARY : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'tools', 'whip_relay')
# NOTE(review): path of a developer machine, it only has an effect where this directory exists
RELAY_LIBRARY_PATH : str = '/home/henry/local/lib'
RELAY_PROCESS : Optional[subprocess.Popen[bytes]] = None


def _get_relay_url(path : str) -> str:
    """
    Prefix a request path with the local address of the relay.
    """
    return 'http://localhost:' + str(RELAY_PORT) + path


def get_whip_url(stream_path : str) -> str:
    """
    Return the publishing (WHIP) url of a stream on the local relay.
    """
    return _get_relay_url('/' + stream_path + '/whip')


def get_whep_url(stream_path : str) -> str:
    """
    Return the viewing (WHEP) url of a stream on the local relay.
    """
    return _get_relay_url('/' + stream_path + '/whep')


def resolve_binary() -> str:
    """
    Return the relay binary found on the PATH, otherwise the bundled RELAY_BINARY location.

    The returned path is not guaranteed to exist.
    """
    return shutil.which('whip_relay') or RELAY_BINARY


def start() -> None:
    """
    Start the relay process on RELAY_PORT.

    A relay started earlier by this module is stopped first and, when the
    'fuser' tool is available, whatever else listens on the port is killed.
    Nothing is started when the binary is missing, a warning is logged instead.
    """
    global RELAY_PROCESS

    if RELAY_PROCESS:
        stop()

    # fuser is not installed everywhere, without it the port is simply not cleared
    if shutil.which('fuser'):
        subprocess.run([ 'fuser', '-k', str(RELAY_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
        time.sleep(0.5)

    relay_binary = resolve_binary()

    if not os.path.isfile(relay_binary):
        logger.warn('whip_relay binary not found at ' + relay_binary + ', skipping', __name__)
        return

    env = os.environ.copy()
    library_paths = [ RELAY_LIBRARY_PATH, env.get('LD_LIBRARY_PATH', '') ]
    # an empty entry would put the current directory on the library search path
    env['LD_LIBRARY_PATH'] = ':'.join(library_path for library_path in library_paths if library_path)
    # NOTE(review): both pipes are never read in this module, a chatty relay can block once they are full - confirm whether callers consume them
    RELAY_PROCESS = subprocess.Popen(
        [ relay_binary, str(RELAY_PORT) ],
        env = env,
        stdout = subprocess.PIPE,
        stderr = subprocess.PIPE
    )
    logger.info('whip relay started on port ' + str(RELAY_PORT), __name__)


def stop() -> None:
    """
    Terminate the relay process, killing it when it does not exit within five seconds.
    """
    global RELAY_PROCESS

    if RELAY_PROCESS:
        RELAY_PROCESS.terminate()

        try:
            RELAY_PROCESS.wait(timeout = 5)
        except subprocess.TimeoutExpired:
            RELAY_PROCESS.kill()
            RELAY_PROCESS.wait()
        RELAY_PROCESS = None


def _is_reachable(path : str) -> bool:
    """
    Tell whether a GET request to the relay path is answered with status 200 within one second.
    """
    try:
        response = httpx.get(_get_relay_url(path), timeout = 1)
    except (httpx.HTTPError, httpx.InvalidURL):
        return False
    return response.status_code == 200


def wait_for_ready() -> bool:
    """
    Poll the health endpoint of the relay up to ten times, half a second apart.

    :return: True as soon as the relay answers with status 200, False otherwise
    """
    for _ in range(10):
        if _is_reachable('/health'):
            return True
        time.sleep(0.5)
    return False


def is_session_ready(stream_path : str) -> bool:
    """
    Tell whether the relay answers the session lookup of the stream path with status 200.
    """
    return _is_reachable('/session/' + stream_path)


def create_session(stream_path : str) -> int:
    """
    Ask the relay to create a session for the stream path.

    :return: the integer sent back by the relay, or 0 when the request fails,
        is answered with another status than 200 or carries no integer
    """
    try:
        response = httpx.post(_get_relay_url('/' + stream_path + '/create'), timeout = 5)

        if response.status_code == 200:
            return int(response.text)
    except (httpx.HTTPError, httpx.InvalidURL, ValueError):
        pass
    return 0
border-radius: 8px; cursor: pointer; display: flex; align-items: center; justify-content: center; background: transparent; color: #888; transition: all 0.15s; } .timeline .transport-btn:hover { background: #1e1e2e; color: #fff; } .timeline .transport-btn:disabled { opacity: 0.25; cursor: not-allowed; } - .timeline .transport-btn.active { color: #00b894; } - .timeline .transport-btn svg { width: 14px; height: 14px; fill: currentColor; } + .timeline .transport-btn.active { color: #888; } + .timeline .transport-btn svg { width: 18px; height: 18px; fill: currentColor; } .timeline .time { font-size: 0.75rem; color: #888; font-family: monospace; min-width: 60px; display: flex; align-items: center; justify-content: center; padding: 0 0.6rem; background: #12121a; border-right: 1px solid #1e1e2e; } .timeline .time:last-child { border-right: none; border-left: 1px solid #1e1e2e; } .timeline .track { flex: 1; position: relative; height: 2em; cursor: pointer; background: repeating-linear-gradient(90deg, transparent, transparent 59px, #1a1a25 59px, #1a1a25 60px); }