From 28ded002fc7a077642b4794ab04db3f1bc0d8d48 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 25 Mar 2026 21:05:53 +0100 Subject: [PATCH] shrink down to the release candidate --- COMPILE_GUIDE.md | 110 ----- e2e_video_modes.py | 209 ++------ facefusion/aiortc_bridge.py | 336 ------------- facefusion/apis/core.py | 47 +- facefusion/apis/endpoints/stream.py | 742 ++-------------------------- facefusion/apis/stream_helper.py | 194 +------- facefusion/mediamtx.py | 107 ---- facefusion/rtc.py | 339 +------------ facefusion/webrtc_sfu.py | 546 -------------------- facefusion/whip_relay.py | 62 --- mediamtx.yml | 9 - test_stream.html | 274 +--------- tests/test_api_stream.py | 37 -- tools/whip_relay | Bin 31280 -> 0 bytes tools/whip_relay.c | 619 ----------------------- 15 files changed, 124 insertions(+), 3507 deletions(-) delete mode 100644 COMPILE_GUIDE.md delete mode 100644 facefusion/aiortc_bridge.py delete mode 100644 facefusion/mediamtx.py delete mode 100644 facefusion/webrtc_sfu.py delete mode 100644 facefusion/whip_relay.py delete mode 100644 mediamtx.yml delete mode 100755 tools/whip_relay delete mode 100644 tools/whip_relay.c diff --git a/COMPILE_GUIDE.md b/COMPILE_GUIDE.md deleted file mode 100644 index 3cf64f9b..00000000 --- a/COMPILE_GUIDE.md +++ /dev/null @@ -1,110 +0,0 @@ -# Compiling libdatachannel - -Prebuilt DLLs from OBS or pip lack VP8 support. We compile from source to get all codecs (H264, VP8, AV1, Opus). - -## Source - -``` -git clone --depth 1 --recurse-submodules https://github.com/paullouisageneau/libdatachannel.git -cd libdatachannel -``` - -## Windows - -Requirements: Visual Studio Build Tools 2019+ with C++ workload, cmake, ninja (available via conda). - -```cmd -call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 -cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF -cmake --build build --config Release -``` - -Output: `build/datachannel.dll` - -Rename to: `windows-x64-openssl-h264-vp8-av1-opus-datachannel-.dll` - -Place in: `bin/` - -## Linux - -Requirements: gcc/g++, cmake, ninja-build, libssl-dev. - -```bash -sudo apt install build-essential cmake ninja-build libssl-dev -cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF -cmake --build build --config Release -``` - -Output: `build/libdatachannel.so` - -Rename to: `linux-x64-openssl-h264-vp8-av1-opus-libdatachannel-.so` - -Install to: `/usr/local/lib/` or project `bin/` - -If installed to a custom path, run `sudo ldconfig` or set `LD_LIBRARY_PATH`. - -## macOS - -Requirements: Xcode Command Line Tools, cmake, ninja. - -```bash -xcode-select --install -brew install cmake ninja -cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF -cmake --build build --config Release -``` - -For universal binary (arm64 + x86_64): - -```bash -cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" -cmake --build build --config Release -``` - -Output: `build/libdatachannel.dylib` - -Rename to: `macos-universal-openssl-h264-vp8-av1-opus-libdatachannel-.dylib` - -Install to: `/usr/local/lib/` or project `bin/` - -## Naming convention - -``` -----datachannel-. -``` - -- os: `windows`, `linux`, `macos` -- arch: `x64`, `arm64`, `universal` -- tls: `openssl` (default), `gnutls`, `mbedtls` -- codecs: supported packetizers, e.g. `h264-vp8-av1-opus` -- version: libdatachannel version, e.g. `0.24.1` - -## Verifying the build - -```python -import ctypes -lib = ctypes.CDLL('path/to/datachannel.dll') -for fn in ['rtcSetH264Packetizer', 'rtcSetVP8Packetizer', 'rtcSetAV1Packetizer', 'rtcSetOpusPacketizer']: - try: - getattr(lib, fn) - print(f'{fn}: OK') - except AttributeError: - print(f'{fn}: MISSING') -``` - -## CMake flags reference - -| Flag | Default | Purpose | -|---|---|---| -| `NO_WEBSOCKET` | OFF | Disable WebSocket support (not needed) | -| `NO_MEDIA` | OFF | Disable media transport (must be OFF for codecs) | -| `NO_EXAMPLES` | OFF | Skip building examples | -| `NO_TESTS` | OFF | Skip building tests | -| `USE_NICE` | OFF | Use libnice instead of libjuice for ICE | -| `USE_GNUTLS` | OFF | Use GnuTLS instead of OpenSSL | -| `USE_MBEDTLS` | OFF | Use Mbed TLS instead of OpenSSL | - -## Runtime dependencies - -- **libopus**: Required for audio encoding. Install via `conda install -c conda-forge libopus` (Windows) or `apt install libopus-dev` (Linux) or `brew install opus` (macOS). -- **OpenSSL**: Usually bundled or system-provided. On Windows, conda provides it. diff --git a/e2e_video_modes.py b/e2e_video_modes.py index 75428706..a1e29986 100644 --- a/e2e_video_modes.py +++ b/e2e_video_modes.py @@ -2,7 +2,6 @@ import os import platform import signal import subprocess -import sys import time import httpx @@ -12,16 +11,8 @@ API_PORT : int = 8400 HTML_FILE : str = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_stream.html') SOURCE_FILE : str = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.assets', 'examples', 'source.jpg') -def is_windows() -> bool: - return platform.system().lower() == 'windows' - -def is_macos() -> bool: - return platform.system().lower() == 'darwin' - -if is_windows(): +if platform.system().lower() == 'windows': VIDEO_FILE : str = 'C:\\Users\\info\\Downloads\\face8k.mp4' -elif is_macos(): - VIDEO_FILE : str = '/Users/henry/Downloads/copy_face_instant.mp4' else: VIDEO_FILE : str = '/home/henry/Documents/examples/download.mp4' @@ -32,25 +23,12 @@ def safe_print(text : str) -> None: except UnicodeEncodeError: print(text.encode('ascii', errors='replace').decode('ascii')) -_ALL_MODES =\ -[ - 'whip-mediamtx', - 'whip-python', - 'whip-datachannel', - 'ws-fmp4', - 'datachannel-direct', - 'datachannel-relay-py', - 'ws-mjpeg' -] - -MODES = [ m for m in _ALL_MODES if not (is_macos() and m == 'whip-mediamtx') ] - def start_api() -> subprocess.Popen: env = os.environ.copy() - python_cmd = 'python' if is_windows() else 'python3' + python_cmd = 'python' if platform.system().lower() == 'windows' else 'python3' - if not is_windows() and not is_macos(): + if platform.system().lower() != 'windows': env['LD_LIBRARY_PATH'] = '/home/henry/local/lib:' + env.get('LD_LIBRARY_PATH', '') proc = subprocess.Popen( @@ -81,7 +59,7 @@ def wait_for_api(timeout : int = 60) -> bool: def stop_api(proc : subprocess.Popen) -> None: - if is_windows(): + if platform.system().lower() == 'windows': proc.terminate() else: proc.send_signal(signal.SIGTERM) @@ -95,64 +73,32 @@ def stop_api(proc : subprocess.Popen) -> None: time.sleep(1) -def kill_port_windows(port : int) -> None: - result = subprocess.run( - [ 'netstat', '-ano' ], - capture_output = True, text = True - ) - - for line in result.stdout.splitlines(): - if ':' + str(port) + ' ' in line and ('LISTENING' in line or 'ESTABLISHED' in line): - parts = line.split() - pid = parts[-1] - - if pid.isdigit() and int(pid) > 0: - subprocess.run([ 'taskkill', '/F', '/PID', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - - -def kill_port_macos(port : int) -> None: - pids = set() - - for proto in [ 'tcp', 'udp' ]: - result = subprocess.run( - [ 'lsof', '-ti', proto + ':' + str(port) ], - capture_output = True, text = True - ) - - for pid in result.stdout.split(): - if pid.isdigit(): - pids.add(pid) - - for pid in pids: - subprocess.run([ 'kill', '-9', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - - def kill_stale() -> None: - ports = [ API_PORT, 8889, 8189, 9997, 8890, 8891, 8892 ] + ports = [ API_PORT ] - if is_windows(): + if platform.system().lower() == 'windows': for port in ports: - kill_port_windows(port) - elif is_macos(): - for port in ports: - kill_port_macos(port) + result = subprocess.run([ 'netstat', '-ano' ], capture_output = True, text = True) + + for line in result.stdout.splitlines(): + if ':' + str(port) + ' ' in line and ('LISTENING' in line or 'ESTABLISHED' in line): + parts = line.split() + pid = parts[-1] + + if pid.isdigit() and int(pid) > 0: + subprocess.run([ 'taskkill', '/F', '/PID', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) else: - subprocess.run([ 'fuser', '-k', str(API_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8889/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '9997/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8890/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8891/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8892/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + for port in ports: + subprocess.run([ 'fuser', '-k', str(port) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) time.sleep(2) -def test_mode(mode : str) -> dict: - result = {'mode': mode, 'session': False, 'source': False, 'video': False, 'ws_open': False, 'stream_ready': False, 'playback': False, 'error': None} +def test_rtc() -> dict: + result = {'session': False, 'source': False, 'video': False, 'ws_open': False, 'stream_ready': False, 'playback': False, 'error': None} print('\n' + '=' * 60) - print('TESTING: ' + mode) + print('TESTING: libdatachannel direct (RTC)') print('=' * 60) kill_stale() @@ -236,11 +182,7 @@ def test_mode(mode : str) -> dict: stop_api(api_proc) return result - print(' video OK, selecting mode: ' + mode) - page.select_option('#streamMode', mode) - time.sleep(0.5) - - print(' starting stream...') + print(' video OK, starting stream...') for _ in range(10): time.sleep(1) @@ -283,68 +225,17 @@ def test_mode(mode : str) -> dict: if 'stream ready' in log_text or 'WHEP' in log_text: result['stream_ready'] = True - if mode == 'ws-mjpeg': - result['stream_ready'] = True - - try: - has_img = page.evaluate('!!document.getElementById("outputVideo")._mjpegImg && !!document.getElementById("outputVideo")._mjpegImg.src') - - if has_img: - result['playback'] = True - print(' [' + str(i) + 's] MJPEG receiving frames') - break - except Exception: - pass - - if mode == 'ws-fmp4': - if 'MSE source buffer ready' in log_text: - result['stream_ready'] = True - - try: - mse_info = page.evaluate('''() => { - var v = document.getElementById("outputVideo"); - var ms = v._mediaSource || window.mediaSource; - var buf = (v.buffered && v.buffered.length > 0) ? v.buffered.end(0) : 0; - return { time: v.currentTime, buffered: buf, readyState: v.readyState, networkState: v.networkState }; - }''') - buffered = mse_info.get('buffered', 0) - - if buffered > 0 or mse_info.get('time', 0) > 0: - result['playback'] = True - print(' [' + str(i) + 's] MSE buffered=' + str(round(buffered, 2)) + ' time=' + str(round(mse_info.get('time', 0), 2))) - break - - if i % 5 == 0: - print(' [' + str(i) + 's] MSE: ' + str(mse_info)) - except Exception: - pass - else: - try: - frames_val = int(frames_stat) if frames_stat and frames_stat != '--' else 0 - except ValueError: - frames_val = 0 - - if frames_val > 0: - result['playback'] = True - print(' [' + str(i) + 's] frames=' + str(frames_val) + ' fps=' + fps_stat + ' rtc=' + rtc_stat) - break - try: - rtc_stats = page.evaluate('''() => { - if (!window.pc) return ''; - return pc.getStats().then(stats => { - var r = ''; - stats.forEach(report => { - if (report.type === 'inbound-rtp' && report.kind === 'video') { - r = 'pkts=' + (report.packetsReceived||0) + ' bytes=' + (report.bytesReceived||0) + ' lost=' + (report.packetsLost||0) + ' dropped=' + (report.framesDropped||0) + ' dec=' + (report.decoderImplementation||'?') + ' kf=' + (report.keyFramesDecoded||0) + ' nacks=' + (report.nackCount||0) + ' plis=' + (report.pliCount||0); - } - }); - return r; - }); - }''') - except Exception: - rtc_stats = '' - print(' [' + str(i) + 's] ws=' + ws_stat + ' rtc=' + rtc_stat + ' frames=' + frames_stat + ' ' + str(rtc_stats)) + frames_val = int(frames_stat) if frames_stat and frames_stat != '--' else 0 + except ValueError: + frames_val = 0 + + if frames_val > 0: + result['playback'] = True + print(' [' + str(i) + 's] frames=' + str(frames_val) + ' fps=' + fps_stat + ' rtc=' + rtc_stat) + break + + print(' [' + str(i) + 's] ws=' + ws_stat + ' rtc=' + rtc_stat + ' frames=' + frames_stat) if not result.get('playback'): log_text = page.locator('#log').text_content() @@ -376,36 +267,26 @@ def test_mode(mode : str) -> dict: def main() -> None: - modes_to_test = MODES - - if len(sys.argv) > 1: - modes_to_test = sys.argv[1:] - - results = [] - - for mode in modes_to_test: - result = test_mode(mode) - results.append(result) + result = test_rtc() print('\n\n' + '=' * 60) - print('SUMMARY') + print('RESULT') print('=' * 60) - for r in results: - status = 'PASS' if r.get('playback') else 'FAIL' - error = ' (' + r.get('error', '') + ')' if r.get('error') else '' - flags = [] + status = 'PASS' if result.get('playback') else 'FAIL' + error = ' (' + result.get('error', '') + ')' if result.get('error') else '' + flags = [] - if r.get('session'): - flags.append('session') - if r.get('ws_open'): - flags.append('ws') - if r.get('stream_ready'): - flags.append('ready') - if r.get('playback'): - flags.append('playback') + if result.get('session'): + flags.append('session') + if result.get('ws_open'): + flags.append('ws') + if result.get('stream_ready'): + flags.append('ready') + if result.get('playback'): + flags.append('playback') - print(' ' + status + ' ' + r.get('mode') + ' [' + ','.join(flags) + ']' + error) + print(' ' + status + ' datachannel-direct [' + ','.join(flags) + ']' + error) if __name__ == '__main__': diff --git a/facefusion/aiortc_bridge.py b/facefusion/aiortc_bridge.py deleted file mode 100644 index bc422c86..00000000 --- a/facefusion/aiortc_bridge.py +++ /dev/null @@ -1,336 +0,0 @@ -import asyncio -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Optional - -import numpy -from aiortc import RTCPeerConnection, RTCSessionDescription, AudioStreamTrack, VideoStreamTrack -from av import AudioFrame, VideoFrame - -from facefusion import logger -from facefusion.types import VisionFrame - -BRIDGE_PORT_START : int = 8893 -AUDIO_SAMPLE_RATE : int = 48000 - - -class FramePushTrack(VideoStreamTrack): - kind = 'video' - - def __init__(self) -> None: - super().__init__() - self._frame : Optional[VisionFrame] = None - self._lock = threading.Lock() - self._started = False - - def push(self, vision_frame : VisionFrame) -> None: - with self._lock: - self._frame = vision_frame - - async def recv(self) -> VideoFrame: - pts, time_base = await self.next_timestamp() - - with self._lock: - frame_data = self._frame - - if frame_data is None: - frame_data = numpy.zeros((240, 320, 3), dtype = numpy.uint8) - - if not self._started: - self._started = True - logger.info('aiortc track sending first frame', __name__) - - video_frame = VideoFrame.from_ndarray(frame_data, format = 'bgr24') - video_frame.pts = pts - video_frame.time_base = time_base - return video_frame - - -class AudioPushTrack(AudioStreamTrack): - kind = 'audio' - - def __init__(self) -> None: - super().__init__() - self._buffer = bytearray() - self._lock = threading.Lock() - self._pts = 0 - self._frame_samples = 960 - - def push(self, pcm_data : bytes) -> None: - with self._lock: - self._buffer.extend(pcm_data) - - if len(self._buffer) > AUDIO_SAMPLE_RATE * 4: - self._buffer = self._buffer[-AUDIO_SAMPLE_RATE * 4:] - - async def recv(self) -> AudioFrame: - await asyncio.sleep(self._frame_samples / AUDIO_SAMPLE_RATE) - needed = self._frame_samples * 2 * 2 - - with self._lock: - if len(self._buffer) >= needed: - chunk = bytes(self._buffer[:needed]) - del self._buffer[:needed] - else: - chunk = None - - if chunk: - pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1) - else: - pcm = numpy.zeros((1, self._frame_samples * 2), dtype = numpy.int16) - - audio_frame = AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo') - audio_frame.sample_rate = AUDIO_SAMPLE_RATE - audio_frame.pts = self._pts - self._pts += self._frame_samples - return audio_frame - - -class AiortcBridge: - def __init__(self) -> None: - global BRIDGE_PORT_START - self.port = BRIDGE_PORT_START - BRIDGE_PORT_START += 1 - self.video_track = FramePushTrack() - self.audio_track = AudioPushTrack() - self.pcs : list = [] - self._http_thread : Optional[threading.Thread] = None - self._running = False - self._has_viewer = False - self._loop = None - - async def start(self) -> None: - self._running = True - self._loop = asyncio.get_event_loop() - self._http_thread = threading.Thread(target = self._run_http, daemon = True) - self._http_thread.start() - logger.info('aiortc bridge started on port ' + str(self.port), __name__) - - async def stop(self) -> None: - self._running = False - - for pc in self.pcs: - try: - loop = asyncio.get_event_loop() - asyncio.run_coroutine_threadsafe(pc.close(), loop) - except Exception: - pass - - def push_frame(self, vision_frame : VisionFrame) -> None: - self.video_track.push(vision_frame) - - def push_audio(self, audio_data : bytes) -> None: - self.audio_track.push(audio_data) - - def has_viewer(self) -> bool: - return self._has_viewer - - def _handle_whep(self, sdp_offer : str) -> Optional[str]: - if not self._loop: - return None - - future = asyncio.run_coroutine_threadsafe(self._create_pc(sdp_offer), self._loop) - - try: - return future.result(timeout = 10) - except Exception as exception: - logger.error('whep error: ' + str(exception), __name__) - return None - - async def _create_pc(self, sdp_offer : str) -> Optional[str]: - pc = RTCPeerConnection() - self.pcs.append(pc) - pc.addTrack(self.video_track) - pc.addTrack(self.audio_track) - - offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') - await pc.setRemoteDescription(offer) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - self._has_viewer = True - return pc.localDescription.sdp - - def _run_http(self) -> None: - bridge = self - - class WhepHandler(BaseHTTPRequestHandler): - def log_message(self, format, *args) -> None: - pass - - def do_OPTIONS(self) -> None: - self.send_response(200) - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'Content-Type') - self.end_headers() - - def do_POST(self) -> None: - content_length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(content_length).decode('utf-8') if content_length else '' - answer = bridge._handle_whep(body) - - if answer: - self.send_response(201) - self.send_header('Content-Type', 'application/sdp') - self.send_header('Access-Control-Allow-Origin', '*') - self.end_headers() - self.wfile.write(answer.encode('utf-8')) - else: - self.send_response(500) - self.send_header('Access-Control-Allow-Origin', '*') - self.end_headers() - - server = HTTPServer(('0.0.0.0', self.port), WhepHandler) - server.timeout = 1 - - while self._running: - server.handle_request() - - -class WhipAiortcBridge: - def __init__(self) -> None: - global BRIDGE_PORT_START - self.port = BRIDGE_PORT_START - BRIDGE_PORT_START += 1 - self.whip_port = BRIDGE_PORT_START - BRIDGE_PORT_START += 1 - self._ingest_pc = None - self._relay_track = None - self._viewer_pcs : list = [] - self._http_thread : Optional[threading.Thread] = None - self._running = False - self._loop = None - self._ingest_ready = False - - async def start(self) -> None: - self._running = True - self._loop = asyncio.get_event_loop() - self._http_thread = threading.Thread(target = self._run_http, daemon = True) - self._http_thread.start() - logger.info('whip-aiortc bridge whip=' + str(self.whip_port) + ' whep=' + str(self.port), __name__) - - async def stop(self) -> None: - self._running = False - - if self._ingest_pc: - await self._ingest_pc.close() - - for pc in self._viewer_pcs: - await pc.close() - - def get_whip_url(self) -> str: - return 'http://localhost:' + str(self.whip_port) + '/whip' - - def get_whep_url(self) -> str: - return 'http://localhost:' + str(self.port) + '/whep' - - def is_ready(self) -> bool: - return self._ingest_ready - - def _handle_whip(self, sdp_offer : str) -> Optional[str]: - if not self._loop: - return None - - future = asyncio.run_coroutine_threadsafe(self._create_ingest(sdp_offer), self._loop) - - try: - return future.result(timeout = 10) - except Exception as exception: - logger.error('whip ingest error: ' + str(exception), __name__) - return None - - def _handle_whep(self, sdp_offer : str) -> Optional[str]: - if not self._loop: - return None - - future = asyncio.run_coroutine_threadsafe(self._create_viewer(sdp_offer), self._loop) - - try: - return future.result(timeout = 10) - except Exception as exception: - logger.error('whep error: ' + str(exception), __name__) - return None - - async def _create_ingest(self, sdp_offer : str) -> Optional[str]: - from aiortc import MediaStreamTrack - from aiortc.contrib.media import MediaRelay - - pc = RTCPeerConnection() - self._ingest_pc = pc - self._relay = MediaRelay() - - @pc.on('track') - def on_track(track : MediaStreamTrack) -> None: - if track.kind == 'video': - self._relay_track = self._relay.subscribe(track) - self._ingest_ready = True - logger.info('whip ingest video track received', __name__) - - offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') - await pc.setRemoteDescription(offer) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - return pc.localDescription.sdp - - async def _create_viewer(self, sdp_offer : str) -> Optional[str]: - pc = RTCPeerConnection() - self._viewer_pcs.append(pc) - - if self._relay_track: - pc.addTrack(self._relay_track) - - offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer') - await pc.setRemoteDescription(offer) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - return pc.localDescription.sdp - - def _run_http(self) -> None: - bridge = self - - class Handler(BaseHTTPRequestHandler): - def log_message(self, format, *args) -> None: - pass - - def do_OPTIONS(self) -> None: - self.send_response(200) - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, DELETE') - self.send_header('Access-Control-Allow-Headers', 'Content-Type, Authorization') - self.end_headers() - - def do_POST(self) -> None: - content_length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(content_length).decode('utf-8') if content_length else '' - path = self.path - - if '/whip' in path: - answer = bridge._handle_whip(body) - elif '/whep' in path: - answer = bridge._handle_whep(body) - else: - self.send_response(404) - self.end_headers() - return - - if answer: - self.send_response(201) - self.send_header('Content-Type', 'application/sdp') - self.send_header('Location', path) - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Expose-Headers', 'Location') - self.end_headers() - self.wfile.write(answer.encode('utf-8')) - else: - self.send_response(500) - self.send_header('Access-Control-Allow-Origin', '*') - self.end_headers() - - whip_server = HTTPServer(('0.0.0.0', self.whip_port), Handler) - whip_server.timeout = 0.5 - whep_server = HTTPServer(('0.0.0.0', self.port), Handler) - whep_server.timeout = 0.5 - - while self._running: - whip_server.handle_request() - whep_server.handle_request() diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index c5dcaf9b..d5d5b7b2 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -6,37 +6,19 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute -from facefusion import logger, mediamtx +from facefusion import logger from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset from facefusion.apis.endpoints.capabilities import get_capabilities from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state -from facefusion.common_helper import is_linux, is_windows -from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_audio, websocket_stream_live, websocket_stream_mjpeg, websocket_stream_rtc, websocket_stream_rtc_relay, websocket_stream_whip, websocket_stream_whip_dc, websocket_stream_whip_py +from facefusion.apis.endpoints.stream import post_whep, websocket_stream, websocket_stream_rtc from facefusion.apis.middlewares.session import create_session_guard @asynccontextmanager async def lifespan(app : Starlette) -> AsyncGenerator[None, None]: - if is_linux(): - mediamtx.start() - mediamtx.wait_for_ready() - - try: - from facefusion import webrtc_sfu - webrtc_sfu.start() - except Exception as exception: - logger.warn('webrtc sfu: ' + str(exception), __name__) - - try: - from facefusion import whip_relay - whip_relay.start() - whip_relay.wait_for_ready() - except Exception as exception: - logger.warn('whip relay: ' + str(exception), __name__) - try: from facefusion import rtc rtc.start() @@ -45,21 +27,6 @@ async def lifespan(app : Starlette) -> AsyncGenerator[None, None]: yield - if is_linux(): - mediamtx.stop() - - try: - from facefusion import webrtc_sfu - webrtc_sfu.stop() - except Exception: - pass - - try: - from facefusion import whip_relay - whip_relay.stop() - except Exception: - pass - try: from facefusion import rtc rtc.stop() @@ -85,15 +52,9 @@ def create_api() -> Starlette: Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]), WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]), + Route('/stream/{session_id}/whep', post_whep, methods = [ 'POST' ]), WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]), - WebSocketRoute('/stream/whip', websocket_stream_whip, middleware = [ session_guard ]), - WebSocketRoute('/stream/whip-py', websocket_stream_whip_py, middleware = [ session_guard ]), - WebSocketRoute('/stream/whip-dc', websocket_stream_whip_dc, middleware = [ session_guard ]), - WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ]), - WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ]), - WebSocketRoute('/stream/rtc-relay', websocket_stream_rtc_relay, middleware = [ session_guard ]), - WebSocketRoute('/stream/mjpeg', websocket_stream_mjpeg, middleware = [ session_guard ]), - WebSocketRoute('/stream/audio', websocket_stream_audio, middleware = [ session_guard ]) + WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ]) ] api = Starlette(routes = routes, lifespan = lifespan) diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 1ae0610e..20729dec 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -1,25 +1,25 @@ import asyncio -import os as _os +import os import threading import time from collections import deque from concurrent.futures import ThreadPoolExecutor -from typing import Deque, List +from typing import List import cv2 import numpy +from starlette.requests import Request +from starlette.responses import Response from starlette.websockets import WebSocket from facefusion import logger, session_context, session_manager, state_manager -from facefusion.common_helper import is_windows -from facefusion.apis.stream_helper import STREAM_AUDIO_RATE from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token -from facefusion import mediamtx -from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_fmp4_encoder, close_whip_encoder, collect_fmp4_chunks, create_fmp4_encoder, create_vp8_pipe_encoder, create_whip_encoder, feed_whip_audio, feed_whip_frame, process_stream_frame, read_fmp4_output +from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, create_vp8_pipe_encoder, feed_whip_frame, process_stream_frame from facefusion.streamer import process_vision_frame from facefusion.types import VisionFrame + JPEG_MAGIC : bytes = b'\xff\xd8' @@ -52,454 +52,16 @@ async def websocket_stream(websocket : WebSocket) -> None: await websocket.close() -def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, audio_write_fd_holder : list, stream_path : str) -> None: - encoder = None - audio_write_fd = -1 - output_deque : Deque[VisionFrame] = deque() - - with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: - futures = [] - - while not stop_event.is_set(): - with lock: - capture_frame = latest_frame_holder[0] - latest_frame_holder[0] = None - - if capture_frame is not None: - future = executor.submit(process_stream_frame, capture_frame) - futures.append(future) - - for future_done in [ future for future in futures if future.done() ]: - output_deque.append(future_done.result()) - futures.remove(future_done) - - if encoder and encoder.poll() is not None: - stderr_output = encoder.stderr.read() if encoder.stderr else b'' - logger.error('encoder died with code ' + str(encoder.returncode) + ': ' + stderr_output.decode(), __name__) - break - - while output_deque: - temp_vision_frame = output_deque.popleft() - - if not encoder: - height, width = temp_vision_frame.shape[:2] - whip_url = mediamtx.get_whip_url(stream_path) - encoder, audio_write_fd = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY, whip_url) - audio_write_fd_holder[0] = audio_write_fd - logger.info('whip encoder started ' + str(width) + 'x' + str(height), __name__) - - feed_whip_frame(encoder, temp_vision_frame) - - if encoder and not ready_event.is_set() and mediamtx.is_path_ready(stream_path): - ready_event.set() - - if capture_frame is None and not output_deque: - time.sleep(0.005) - - if encoder: - close_whip_encoder(encoder, audio_write_fd) - - -async def websocket_stream_whip(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - stream_path = 'stream/' + session_id - mediamtx.remove_path(stream_path) - mediamtx.add_path(stream_path) - logger.info('mediamtx path added ' + stream_path, __name__) - - latest_frame_holder : list = [None] - audio_write_fd_holder : list = [-1] - whep_sent = False - lock = threading.Lock() - stop_event = threading.Event() - ready_event = threading.Event() - worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, audio_write_fd_holder, stream_path), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - if not whep_sent and ready_event.is_set(): - whep_url = mediamtx.get_whep_url(stream_path) - await websocket.send_text(whep_url) - whep_sent = True - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC and audio_write_fd_holder[0] > 0: - feed_whip_audio(audio_write_fd_holder[0], data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - mediamtx.remove_path(stream_path) - return - - await websocket.close() - - -def run_audio_silence_feeder(audio_write_fd_holder : list, stop_event : threading.Event, audio_active_event : threading.Event) -> None: - frame_bytes = STREAM_AUDIO_RATE // 50 * 2 * 2 - silence = b'\x00' * frame_bytes - - while not stop_event.is_set(): - if not audio_active_event.is_set(): - fd = audio_write_fd_holder[0] - - if fd > 0: - try: - _os.write(fd, silence) - except OSError: - break - - time.sleep(0.02) - - -def run_fmp4_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, output_chunks : List[bytes], output_lock : threading.Lock, audio_write_fd_holder : list, audio_active_event : threading.Event) -> None: - encoder = None - audio_write_fd = -1 - reader_thread = None - output_deque : Deque[VisionFrame] = deque() - - with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: - futures = [] - - while not stop_event.is_set(): - with lock: - capture_frame = latest_frame_holder[0] - latest_frame_holder[0] = None - - if capture_frame is not None: - future = executor.submit(process_stream_frame, capture_frame) - futures.append(future) - - for future_done in [ future for future in futures if future.done() ]: - output_deque.append(future_done.result()) - futures.remove(future_done) - - if encoder and encoder.poll() is not None: - stderr_output = encoder.stderr.read() if encoder.stderr else b'' - logger.error('fmp4 encoder died with code ' + str(encoder.returncode) + ': ' + stderr_output.decode(), __name__) - break - - while output_deque: - temp_vision_frame = output_deque.popleft() - - if not encoder: - height, width = temp_vision_frame.shape[:2] - encoder, audio_write_fd = create_fmp4_encoder(width, height, STREAM_FPS, STREAM_QUALITY) - audio_write_fd_holder[0] = audio_write_fd - reader_thread = threading.Thread(target = read_fmp4_output, args = (encoder, output_chunks, output_lock), daemon = True) - reader_thread.start() - silence_thread = threading.Thread(target = run_audio_silence_feeder, args = (audio_write_fd_holder, stop_event, audio_active_event), daemon = True) - silence_thread.start() - logger.info('fmp4 encoder started ' + str(width) + 'x' + str(height), __name__) - - feed_whip_frame(encoder, temp_vision_frame) - - if capture_frame is None and not output_deque: - time.sleep(0.005) - - if encoder: - close_fmp4_encoder(encoder, audio_write_fd) - - -async def websocket_stream_live(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - latest_frame_holder : list = [None] - audio_write_fd_holder : list = [-1] - output_chunks : List[bytes] = [] - lock = threading.Lock() - output_lock = threading.Lock() - stop_event = threading.Event() - audio_active_event = threading.Event() - worker = threading.Thread(target = run_fmp4_pipeline, args = (latest_frame_holder, lock, stop_event, output_chunks, output_lock, audio_write_fd_holder, audio_active_event), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - chunks = collect_fmp4_chunks(output_chunks, output_lock) - - if chunks: - await websocket.send_bytes(chunks) - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC and audio_write_fd_holder[0] > 0: - audio_active_event.set() - feed_whip_audio(audio_write_fd_holder[0], data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - return - - await websocket.close() - - -def run_mjpeg_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, output_holder : list, output_lock : threading.Lock) -> None: - output_deque : Deque[VisionFrame] = deque() - - with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: - futures = [] - - while not stop_event.is_set(): - with lock: - capture_frame = latest_frame_holder[0] - latest_frame_holder[0] = None - - if capture_frame is not None: - future = executor.submit(process_stream_frame, capture_frame) - futures.append(future) - - for future_done in [ future for future in futures if future.done() ]: - output_deque.append(future_done.result()) - futures.remove(future_done) - - while output_deque: - temp_vision_frame = output_deque.popleft() - is_success, encoded = cv2.imencode('.jpg', temp_vision_frame, [cv2.IMWRITE_JPEG_QUALITY, 92]) - - if is_success: - with output_lock: - output_holder[0] = encoded.tobytes() - - if capture_frame is None and not output_deque: - time.sleep(0.005) - - -async def websocket_stream_mjpeg(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - latest_frame_holder : list = [None] - output_holder : list = [None] - lock = threading.Lock() - output_lock = threading.Lock() - stop_event = threading.Event() - worker = threading.Thread(target = run_mjpeg_pipeline, args = (latest_frame_holder, lock, stop_event, output_holder, output_lock), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - with output_lock: - jpeg_data = output_holder[0] - output_holder[0] = None - - if jpeg_data: - await websocket.send_bytes(jpeg_data) - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - return - - await websocket.close() - - -async def websocket_stream_audio(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - - await websocket.accept(subprotocol = subprotocol) - - try: - while True: - message = await websocket.receive() - - if message.get('bytes'): - await websocket.send_bytes(message.get('bytes')) - except Exception: - pass - - -async def websocket_stream_whip_py(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - from facefusion.aiortc_bridge import AiortcBridge - - bridge = AiortcBridge() - await bridge.start() - whep_url = 'http://localhost:' + str(bridge.port) + '/whep' - - latest_frame_holder : list = [None] - whep_sent = False - lock = threading.Lock() - stop_event = threading.Event() - ready_event = threading.Event() - worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - if not whep_sent and ready_event.is_set(): - await websocket.send_text(whep_url) - whep_sent = True - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC: - bridge.push_audio(data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - await bridge.stop() - return - - await websocket.close() - - -def run_aiortc_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, bridge : object) -> None: - output_deque : Deque[VisionFrame] = deque() - - with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: - futures = [] - - while not stop_event.is_set(): - with lock: - capture_frame = latest_frame_holder[0] - latest_frame_holder[0] = None - - if capture_frame is not None: - future = executor.submit(process_stream_frame, capture_frame) - futures.append(future) - - for future_done in [ future for future in futures if future.done() ]: - output_deque.append(future_done.result()) - futures.remove(future_done) - - while output_deque: - temp_vision_frame = output_deque.popleft() - bridge.push_frame(temp_vision_frame) - - if not ready_event.is_set(): - time.sleep(2) - ready_event.set() - - if capture_frame is None and not output_deque: - time.sleep(0.005) - - -def read_h264_output(process, h264_chunks : List[bytes], h264_lock : threading.Lock) -> None: - fd = process.stdout.fileno() - - if not is_windows(): - import fcntl - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~_os.O_NONBLOCK) - - while True: - chunk = _os.read(fd, 4096) - - if not chunk: - break - - with h264_lock: - h264_chunks.append(chunk) - - def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lock) -> None: - fd = process.stdout.fileno() + pipe_handle = process.stdout.fileno() - if not is_windows(): - import fcntl - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~_os.O_NONBLOCK) + if os.name != 'nt': + os.set_blocking(pipe_handle, True) header = b'' while len(header) < 32: - chunk = _os.read(fd, 32 - len(header)) + chunk = os.read(pipe_handle, 32 - len(header)) if not chunk: return @@ -510,7 +72,7 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo frame_header = b'' while len(frame_header) < 12: - chunk = _os.read(fd, 12 - len(frame_header)) + chunk = os.read(pipe_handle, 12 - len(frame_header)) if not chunk: return @@ -521,7 +83,7 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo frame_data = b'' while len(frame_data) < frame_size: - chunk = _os.read(fd, frame_size - len(frame_data)) + chunk = os.read(pipe_handle, frame_size - len(frame_data)) if not chunk: return @@ -532,203 +94,13 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo frame_list.append(frame_data) -def run_h264_dc_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, backend : str, stream_path : str, rtp_port : int) -> None: - encoder = None - reader_thread = None - vp8_frames : List[bytes] = [] - vp8_lock = threading.Lock() - output_deque : Deque[VisionFrame] = deque() - udp_sock = None - - if backend == 'relay': - import socket as sock - udp_sock = sock.socket(sock.AF_INET, sock.SOCK_DGRAM) - - with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: - futures = [] - - while not stop_event.is_set(): - with lock: - capture_frame = latest_frame_holder[0] - latest_frame_holder[0] = None - - if capture_frame is not None: - future = executor.submit(process_stream_frame, capture_frame) - futures.append(future) - - for future_done in [ future for future in futures if future.done() ]: - output_deque.append(future_done.result()) - futures.remove(future_done) - - if encoder and encoder.poll() is not None: - stderr_output = encoder.stderr.read() if encoder.stderr else b'' - logger.error('vp8 encoder died: ' + stderr_output.decode(), __name__) - break - - while output_deque: - temp_vision_frame = output_deque.popleft() - - if not encoder: - height, width = temp_vision_frame.shape[:2] - encoder = create_vp8_pipe_encoder(width, height, STREAM_FPS, STREAM_QUALITY) - reader_thread = threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True) - reader_thread.start() - logger.info('vp8 encoder started ' + str(width) + 'x' + str(height) + ' [' + backend + ']', __name__) - - feed_whip_frame(encoder, temp_vision_frame) - - with vp8_lock: - if vp8_frames: - pending = list(vp8_frames) - vp8_frames.clear() - - for frame in pending: - if backend == 'relay' and udp_sock: - if len(frame) <= 64999: - udp_sock.sendto(b'\x01' + frame, ('127.0.0.1', rtp_port)) - if backend == 'rtc': - from facefusion import rtc - rtc.send_vp8_frame(stream_path, frame) - - if not ready_event.is_set() and encoder and encoder.poll() is None: - time.sleep(1) - ready_event.set() - - if capture_frame is None and not output_deque: - time.sleep(0.005) - - if encoder: - encoder.stdin.close() - encoder.terminate() - encoder.wait(timeout = 5) - - if udp_sock: - udp_sock.close() - - -async def websocket_stream_whip_dc(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - from facefusion.aiortc_bridge import AiortcBridge - - bridge = AiortcBridge() - await bridge.start() - whep_url = 'http://localhost:' + str(bridge.port) + '/whep' - - latest_frame_holder : list = [None] - whep_sent = False - lock = threading.Lock() - stop_event = threading.Event() - ready_event = threading.Event() - worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - if not whep_sent and ready_event.is_set(): - await websocket.send_text(whep_url) - whep_sent = True - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC: - bridge.push_audio(data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - await bridge.stop() - return - - await websocket.close() - - -async def websocket_stream_whip_aio(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - from facefusion.aiortc_bridge import AiortcBridge - - bridge = AiortcBridge() - await bridge.start() - whep_url = 'http://localhost:' + str(bridge.port) + '/whep' - - latest_frame_holder : list = [None] - whep_sent = False - lock = threading.Lock() - stop_event = threading.Event() - ready_event = threading.Event() - worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - if not whep_sent and ready_event.is_set(): - await websocket.send_text(whep_url) - whep_sent = True - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC: - bridge.push_audio(data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - await bridge.stop() - return - - await websocket.close() - - def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, stream_path : str) -> None: from facefusion import rtc + encoder = None - reader_thread = None vp8_frames : List[bytes] = [] vp8_lock = threading.Lock() - output_deque : Deque[VisionFrame] = deque() + output_deque : deque = deque() with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: futures = [] @@ -764,8 +136,7 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s if not encoder: height, width = temp_vision_frame.shape[:2] encoder = create_vp8_pipe_encoder(width, height, STREAM_FPS, STREAM_QUALITY) - reader_thread = threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True) - reader_thread.start() + threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True).start() logger.info('vp8 direct encoder started ' + str(width) + 'x' + str(height), __name__) feed_whip_frame(encoder, temp_vision_frame) @@ -791,71 +162,6 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s encoder.wait(timeout = 5) -async def websocket_stream_rtc_relay(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') - - logger.info('rtc-relay: session_id=' + str(session_id) + ' source_paths=' + str(bool(source_paths)), __name__) - - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - from facefusion import rtc - - if not rtc.lib: - logger.error('rtc-relay: libdatachannel not loaded', __name__) - await websocket.close() - return - - stream_path = 'stream/' + session_id - rtc.create_session(stream_path) - whep_url = 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep' - - latest_frame_holder : list = [None] - whep_sent = False - lock = threading.Lock() - stop_event = threading.Event() - ready_event = threading.Event() - worker = threading.Thread(target = run_rtc_direct_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, stream_path), daemon = True) - worker.start() - - try: - while True: - message = await websocket.receive() - - if not whep_sent and ready_event.is_set(): - await websocket.send_text(whep_url) - whep_sent = True - - if message.get('bytes'): - data = message.get('bytes') - - if data[:2] == JPEG_MAGIC: - frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(frame): - with lock: - latest_frame_holder[0] = frame - - if data[:2] != JPEG_MAGIC: - rtc.send_audio(stream_path, data) - - except Exception as exception: - logger.error(str(exception), __name__) - - stop_event.set() - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, worker.join, 10) - rtc.destroy_session(stream_path) - return - - await websocket.close() - - async def websocket_stream_rtc(websocket : WebSocket) -> None: subprotocol = get_sec_websocket_protocol(websocket.scope) access_token = extract_access_token(websocket.scope) @@ -868,9 +174,10 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None: if source_paths: from facefusion import rtc + stream_path = 'stream/' + session_id rtc.create_session(stream_path) - whep_url = 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep' + whep_url = '/' + stream_path + '/whep' latest_frame_holder : list = [None] whep_sent = False @@ -911,3 +218,18 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None: return await websocket.close() + + +async def post_whep(request : Request) -> Response: + from facefusion import rtc + + session_id = request.path_params.get('session_id') + stream_path = 'stream/' + session_id + body = await request.body() + sdp_offer = body.decode('utf-8') + loop = asyncio.get_running_loop() + answer = await loop.run_in_executor(None, rtc.handle_whep_offer, stream_path, sdp_offer) + + if answer: + return Response(answer, status_code = 201, media_type = 'application/sdp') + return Response(status_code = 404) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index c1a110b4..21c01b72 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,21 +1,13 @@ -import os import subprocess -import tempfile -import threading -from typing import List, Optional, Tuple import cv2 from facefusion import ffmpeg_builder -from facefusion.common_helper import is_windows from facefusion.streamer import process_vision_frame from facefusion.types import VisionFrame STREAM_FPS : int = 30 STREAM_QUALITY : int = 80 -STREAM_AUDIO_RATE : int = 48000 -DTLS_CERT_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_cert.pem') -DTLS_KEY_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_key.pem') def compute_bitrate(width : int, height : int) -> str: @@ -46,186 +38,6 @@ def compute_bufsize(width : int, height : int) -> str: return '10000k' -def create_dtls_certificate() -> None: - if os.path.isfile(DTLS_CERT_FILE) and os.path.isfile(DTLS_KEY_FILE): - return - - subprocess.run([ - 'openssl', 'req', '-x509', '-newkey', 'ec', '-pkeyopt', 'ec_paramgen_curve:prime256v1', - '-keyout', DTLS_KEY_FILE, '-out', DTLS_CERT_FILE, - '-days', '365', '-nodes', '-subj', '/CN=facefusion' - ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - - -def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int, whip_url : str) -> Tuple[subprocess.Popen[bytes], int]: - create_dtls_certificate() - audio_read_fd, audio_write_fd = os.pipe() - commands = ffmpeg_builder.chain( - [ '-use_wallclock_as_timestamps', '1' ], - ffmpeg_builder.capture_video(), - ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), - ffmpeg_builder.set_input('-'), - [ '-use_wallclock_as_timestamps', '1' ], - [ '-f', 's16le', '-ar', str(STREAM_AUDIO_RATE), '-ac', '2', '-i', 'pipe:' + str(audio_read_fd) ], - ffmpeg_builder.set_video_encoder('libx264'), - ffmpeg_builder.set_video_quality('libx264', stream_quality), - ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), - [ '-pix_fmt', 'yuv420p' ], - [ '-profile:v', 'baseline' ], - [ '-tune', 'zerolatency' ], - [ '-maxrate', compute_bitrate(width, height) ], - [ '-bufsize', compute_bufsize(width, height) ], - [ '-g', str(stream_fps) ], - [ '-c:a', 'libopus' ], - [ '-f', 'whip' ], - [ '-cert_file', DTLS_CERT_FILE ], - [ '-key_file', DTLS_KEY_FILE ], - ffmpeg_builder.set_output(whip_url) - ) - commands = ffmpeg_builder.run(commands) - - if is_windows(): - os.set_inheritable(audio_read_fd, True) - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE, close_fds = False) - else: - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE, pass_fds = (audio_read_fd,)) - - os.close(audio_read_fd) - return process, audio_write_fd - - -def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None: - raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes() - process.stdin.write(raw_bytes) - process.stdin.flush() - - -def feed_whip_audio(audio_write_fd : int, audio_data : bytes) -> None: - os.write(audio_write_fd, audio_data) - - -def close_whip_encoder(process : subprocess.Popen[bytes], audio_write_fd : int) -> None: - os.close(audio_write_fd) - process.stdin.close() - process.terminate() - process.wait(timeout = 5) - - -def create_fmp4_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> Tuple[subprocess.Popen[bytes], int]: - audio_read_fd, audio_write_fd = os.pipe() - commands = ffmpeg_builder.chain( - [ '-use_wallclock_as_timestamps', '1' ], - ffmpeg_builder.capture_video(), - ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), - ffmpeg_builder.set_input('-'), - [ '-use_wallclock_as_timestamps', '1' ], - [ '-f', 's16le', '-ar', str(STREAM_AUDIO_RATE), '-ac', '2', '-i', 'pipe:' + str(audio_read_fd) ], - [ '-thread_queue_size', '512' ], - ffmpeg_builder.set_video_encoder('libx264'), - ffmpeg_builder.set_video_quality('libx264', stream_quality), - ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), - [ '-pix_fmt', 'yuv420p' ], - [ '-profile:v', 'baseline' ], - [ '-tune', 'zerolatency' ], - [ '-maxrate', compute_bitrate(width, height) ], - [ '-bufsize', compute_bufsize(width, height) ], - [ '-g', str(stream_fps) ], - [ '-c:a', 'aac' ], - [ '-b:a', '128k' ], - [ '-f', 'mp4' ], - [ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ], - ffmpeg_builder.set_output('-') - ) - commands = ffmpeg_builder.run(commands) - - if is_windows(): - os.set_inheritable(audio_read_fd, True) - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE, close_fds = False) - else: - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE, pass_fds = (audio_read_fd,)) - - os.close(audio_read_fd) - return process, audio_write_fd - - -def read_fmp4_output(process : subprocess.Popen[bytes], output_chunks : List[bytes], lock : threading.Lock) -> None: - while True: - chunk = process.stdout.read(4096) - - if not chunk: - break - - with lock: - output_chunks.append(chunk) - - -def collect_fmp4_chunks(output_chunks : List[bytes], lock : threading.Lock) -> Optional[bytes]: - with lock: - if output_chunks: - encoded_bytes = b''.join(output_chunks) - output_chunks.clear() - return encoded_bytes - - return None - - -def close_fmp4_encoder(process : subprocess.Popen[bytes], audio_write_fd : int) -> None: - if audio_write_fd > 0: - os.close(audio_write_fd) - process.stdin.close() - process.terminate() - process.wait(timeout = 5) - - -def create_rtp_encoder(width : int, height : int, stream_fps : int, stream_quality : int, rtp_port : int) -> subprocess.Popen[bytes]: - commands = ffmpeg_builder.chain( - [ '-use_wallclock_as_timestamps', '1' ], - ffmpeg_builder.capture_video(), - ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), - ffmpeg_builder.set_input('-'), - ffmpeg_builder.set_video_encoder('libx264'), - ffmpeg_builder.set_video_quality('libx264', stream_quality), - ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), - [ '-pix_fmt', 'yuv420p' ], - [ '-profile:v', 'baseline' ], - [ '-tune', 'zerolatency' ], - [ '-maxrate', compute_bitrate(width, height) ], - [ '-bufsize', compute_bufsize(width, height) ], - [ '-g', str(stream_fps) ], - [ '-an' ], - [ '-f', 'rtp' ], - [ '-payload_type', '96' ], - ffmpeg_builder.set_output('rtp://127.0.0.1:' + str(rtp_port) + '?pkt_size=1200') - ) - commands = ffmpeg_builder.run(commands) - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE) - return process - - -def create_h264_pipe_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]: - commands = ffmpeg_builder.chain( - [ '-use_wallclock_as_timestamps', '1' ], - ffmpeg_builder.capture_video(), - ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), - ffmpeg_builder.set_input('-'), - ffmpeg_builder.set_video_encoder('libx264'), - ffmpeg_builder.set_video_quality('libx264', stream_quality), - ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), - [ '-pix_fmt', 'yuv420p' ], - [ '-profile:v', 'baseline' ], - [ '-tune', 'zerolatency' ], - [ '-maxrate', compute_bitrate(width, height) ], - [ '-bufsize', compute_bufsize(width, height) ], - [ '-g', '1' ], - [ '-an' ], - [ '-f', 'h264' ], - ffmpeg_builder.set_output('-') - ) - commands = ffmpeg_builder.run(commands) - process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE) - return process - - def create_vp8_pipe_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]: commands = ffmpeg_builder.chain( [ '-use_wallclock_as_timestamps', '1' ], @@ -255,5 +67,11 @@ def create_vp8_pipe_encoder(width : int, height : int, stream_fps : int, stream_ return process +def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None: + raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes() + process.stdin.write(raw_bytes) + process.stdin.flush() + + def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame: return process_vision_frame(vision_frame) diff --git a/facefusion/mediamtx.py b/facefusion/mediamtx.py deleted file mode 100644 index cabf4fcc..00000000 --- a/facefusion/mediamtx.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -import shutil -import subprocess -import time -from typing import Optional - -import httpx - -from facefusion.common_helper import is_linux - - -MEDIAMTX_WHIP_PORT : int = 8889 -MEDIAMTX_API_PORT : int = 9997 -MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'mediamtx.yml') -MEDIAMTX_FALLBACK_BINARY : str = '/home/henry/local/bin/mediamtx' -MEDIAMTX_PROCESS : Optional[subprocess.Popen[bytes]] = None - - -def get_whip_url(stream_path : str) -> str: - return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + stream_path + '/whip' - - -def get_whep_url(stream_path : str) -> str: - return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + stream_path + '/whep' - - -def get_api_url() -> str: - return 'http://localhost:' + str(MEDIAMTX_API_PORT) - - -def resolve_binary() -> str: - mediamtx_path = shutil.which('mediamtx') - - if mediamtx_path: - return mediamtx_path - return MEDIAMTX_FALLBACK_BINARY - - -def start() -> None: - global MEDIAMTX_PROCESS - - stop_stale() - mediamtx_binary = resolve_binary() - MEDIAMTX_PROCESS = subprocess.Popen( - [ mediamtx_binary, MEDIAMTX_CONFIG ], - stdout = subprocess.DEVNULL, - stderr = subprocess.DEVNULL - ) - - -def stop() -> None: - global MEDIAMTX_PROCESS - - if MEDIAMTX_PROCESS: - MEDIAMTX_PROCESS.terminate() - MEDIAMTX_PROCESS.wait() - MEDIAMTX_PROCESS = None - - -def stop_stale() -> None: - if is_linux(): - subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', str(MEDIAMTX_API_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - time.sleep(1) - - -def wait_for_ready() -> bool: - api_url = get_api_url() + '/v3/paths/list' - - for _ in range(10): - try: - response = httpx.get(api_url, timeout = 1) - - if response.status_code == 200: - return True - except Exception: - pass - time.sleep(0.5) - return False - - -def is_path_ready(stream_path : str) -> bool: - api_url = get_api_url() + '/v3/paths/get/' + stream_path - - try: - response = httpx.get(api_url, timeout = 1) - - if response.status_code == 200: - return response.json().get('ready', False) - except Exception: - pass - return False - - -def add_path(stream_path : str) -> bool: - api_url = get_api_url() + '/v3/config/paths/add/' + stream_path - response = httpx.post(api_url, json = {}, timeout = 5) - - return response.status_code == 200 - - -def remove_path(stream_path : str) -> bool: - api_url = get_api_url() + '/v3/config/paths/delete/' + stream_path - response = httpx.delete(api_url, timeout = 5) - - return response.status_code == 200 diff --git a/facefusion/rtc.py b/facefusion/rtc.py index 6ee0495f..f0dda424 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -2,45 +2,24 @@ import ctypes import ctypes.util import os import threading -import time as _time -from http.server import BaseHTTPRequestHandler, HTTPServer +import time from typing import Dict, List, Optional, TypeAlias from facefusion import logger -from facefusion.common_helper import is_macos, is_windows RtcLib : TypeAlias = ctypes.CDLL -WHEP_PORT : int = 8892 lib : Optional[RtcLib] = None sessions : Dict[str, dict] = {} -http_thread : Optional[threading.Thread] = None -running : bool = False -RTC_NEW = 0 -RTC_CONNECTING = 1 RTC_CONNECTED = 2 -RTC_DISCONNECTED = 3 -RTC_FAILED = 4 -RTC_CLOSED = 5 - -RTC_GATHERING_NEW = 0 -RTC_GATHERING_INPROGRESS = 1 RTC_GATHERING_COMPLETE = 2 -RTC_DIRECTION_SENDONLY = 0 -RTC_DIRECTION_RECVONLY = 1 -RTC_DIRECTION_SENDRECV = 2 -RTC_DIRECTION_INACTIVE = 3 -RTC_DIRECTION_UNKNOWN = 4 - LOG_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) DESCRIPTION_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p) CANDIDATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p) STATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) GATHERING_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) -TRACK_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) -MESSAGE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) class RtcConfiguration(ctypes.Structure): @@ -85,11 +64,16 @@ def find_library() -> Optional[str]: project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) bin_dir = os.path.join(project_root, 'bin') - if is_windows(): - return os.path.join(bin_dir, 'windows-x64-openssl-h264-vp8-av1-opus-datachannel-0.24.1.dll') - if is_macos(): - return os.path.join(bin_dir, 'macos-universal-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.dylib') - return os.path.join(bin_dir, 'linux-x64-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.so') + if not os.path.isdir(bin_dir): + return None + + ext = '.dll' if os.name == 'nt' else '.so' + + for name in os.listdir(bin_dir): + if 'datachannel' in name and name.endswith(ext): + return os.path.join(bin_dir, name) + + return None def load_library() -> bool: @@ -121,9 +105,6 @@ def setup_prototypes() -> None: lib.rtcDeletePeerConnection.argtypes = [ctypes.c_int] lib.rtcDeletePeerConnection.restype = ctypes.c_int - lib.rtcSetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p] - lib.rtcSetLocalDescription.restype = ctypes.c_int - lib.rtcSetRemoteDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p] lib.rtcSetRemoteDescription.restype = ctypes.c_int @@ -133,9 +114,6 @@ def setup_prototypes() -> None: lib.rtcAddTrack.argtypes = [ctypes.c_int, ctypes.c_char_p] lib.rtcAddTrack.restype = ctypes.c_int - lib.rtcSetUserPointer.argtypes = [ctypes.c_int, ctypes.c_void_p] - lib.rtcSetUserPointer.restype = None - lib.rtcSetLocalDescriptionCallback.argtypes = [ctypes.c_int, DESCRIPTION_CALLBACK_TYPE] lib.rtcSetLocalDescriptionCallback.restype = ctypes.c_int @@ -148,18 +126,9 @@ def setup_prototypes() -> None: lib.rtcSetGatheringStateChangeCallback.argtypes = [ctypes.c_int, GATHERING_CALLBACK_TYPE] lib.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int - lib.rtcSetTrackCallback.argtypes = [ctypes.c_int, TRACK_CALLBACK_TYPE] - lib.rtcSetTrackCallback.restype = ctypes.c_int - - lib.rtcSetMessageCallback.argtypes = [ctypes.c_int, MESSAGE_CALLBACK_TYPE] - lib.rtcSetMessageCallback.restype = ctypes.c_int - lib.rtcSendMessage.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_int] lib.rtcSendMessage.restype = ctypes.c_int - lib.rtcSetH264Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)] - lib.rtcSetH264Packetizer.restype = ctypes.c_int - lib.rtcSetVP8Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)] lib.rtcSetVP8Packetizer.restype = ctypes.c_int @@ -204,95 +173,11 @@ def create_peer_connection() -> int: return lib.rtcCreatePeerConnection(ctypes.byref(config)) -next_rtp_port : int = 16000 - - def create_session(stream_path : str) -> None: - global video_frame_count - video_frame_count = 0 - sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': 0, 'rtp_fd': None} - - -def create_rtp_session(stream_path : str) -> int: - global next_rtp_port - import socket as sock - - rtp_port = next_rtp_port - next_rtp_port += 1 - - rtp_fd = sock.socket(sock.AF_INET, sock.SOCK_DGRAM) - rtp_fd.bind(('127.0.0.1', rtp_port)) - rtp_fd.settimeout(1.0) - - sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': rtp_port, 'rtp_fd': rtp_fd} - - rtp_thread = threading.Thread(target = run_rtp_forwarder, args = (stream_path,), daemon = True) - rtp_thread.start() - - return rtp_port - - -def run_rtp_forwarder(stream_path : str) -> None: - session = sessions.get(stream_path) - - if not session: - return - - rtp_fd = session.get('rtp_fd') - - while running and session.get('rtp_fd'): - try: - data, addr = rtp_fd.recvfrom(262144) - - if len(data) < 2: - continue - - tag = data[0] - payload = data[1:] - - if tag == 0x01: - send_to_viewers(stream_path, payload) - if tag == 0x02: - send_audio_to_viewers(stream_path, payload) - except Exception: - continue - - -def send_audio_to_viewers(stream_path : str, opus_data : bytes) -> None: - global audio_pts - - session = sessions.get(stream_path) - - if not session: - return - - viewers = session.get('viewers') - - if not viewers: - return - - buf = ctypes.create_string_buffer(opus_data) - - for viewer in viewers: - if not viewer.get('connected'): - continue - - audio_track_id = viewer.get('audio_track') - - if not audio_track_id: - continue - - if not lib.rtcIsOpen(audio_track_id): - continue - - lib.rtcSetTrackRtpTimestamp(audio_track_id, audio_pts & 0xFFFFFFFF) - lib.rtcSendMessage(audio_track_id, buf, len(opus_data)) - - audio_pts += OPUS_FRAME_SAMPLES + sessions[stream_path] = {'viewers': []} send_start_time : float = 0 -video_frame_count : int = 0 audio_pts : int = 0 opus_enc = None audio_buffer : bytearray = bytearray() @@ -301,7 +186,7 @@ OPUS_FRAME_SAMPLES : int = 960 def send_to_viewers(stream_path : str, data : bytes) -> None: - global video_frame_count + global send_start_time session = sessions.get(stream_path) @@ -313,8 +198,11 @@ def send_to_viewers(stream_path : str, data : bytes) -> None: if not viewers: return - timestamp = int(video_frame_count * 3000) & 0xFFFFFFFF - video_frame_count += 1 + if send_start_time == 0: + send_start_time = time.monotonic() + + elapsed = time.monotonic() - send_start_time + timestamp = int(elapsed * 90000) & 0xFFFFFFFF buf = ctypes.create_string_buffer(data) data_len = len(data) @@ -370,10 +258,6 @@ def encode_opus_frame(pcm_data : bytes) -> Optional[bytes]: return None -def get_opus_encoder() -> None: - init_opus_encoder() - - def send_audio(stream_path : str, pcm_data : bytes) -> None: global audio_pts @@ -422,112 +306,9 @@ def send_audio(stream_path : str, pcm_data : bytes) -> None: audio_pts += OPUS_FRAME_SAMPLES -h264_au_buffer : Dict[str, bytes] = {} - - -def send_vp8_frame(stream_path : str, frame_data : bytes) -> None: - send_h264_frame(stream_path, frame_data) - - -def find_nal_starts(data : bytes) -> List: - starts = [] - i = 0 - - while i < len(data) - 3: - if data[i] == 0 and data[i + 1] == 0: - if data[i + 2] == 1: - starts.append((i, 3)) - i += 3 - continue - if i < len(data) - 4 and data[i + 2] == 0 and data[i + 3] == 1: - starts.append((i, 4)) - i += 4 - continue - - i += 1 - - return starts - - -def send_h264_frame(stream_path : str, frame_data : bytes) -> None: - global send_start_time - - session = sessions.get(stream_path) - - if not session: - return - - viewers = session.get('viewers') - - if not viewers: - return - - prev = h264_au_buffer.get(stream_path, b'') - buf = prev + frame_data - nal_starts = find_nal_starts(buf) - - if len(nal_starts) < 2: - h264_au_buffer[stream_path] = buf - return - - au_boundaries = [] - - for idx, (pos, sc_len) in enumerate(nal_starts): - nal_type = buf[pos + sc_len] & 0x1f - - if nal_type == 7: - au_boundaries.append(idx) - - if len(au_boundaries) < 2: - h264_au_buffer[stream_path] = buf - return - - if send_start_time == 0: - send_start_time = _time.monotonic() - - elapsed = _time.monotonic() - send_start_time - frame_duration = 1.0 / 30.0 - - for k in range(len(au_boundaries) - 1): - start_nal = au_boundaries[k] - end_nal = au_boundaries[k + 1] - timestamp = int((elapsed + k * frame_duration) * 90000) & 0xFFFFFFFF - - nalu_parts = [] - - for nal_idx in range(start_nal, end_nal): - nal_pos = nal_starts[nal_idx][0] - nal_sc_len = nal_starts[nal_idx][1] - - if nal_idx + 1 < len(nal_starts): - nal_end = nal_starts[nal_idx + 1][0] - else: - nal_end = len(buf) - - nalu = buf[nal_pos + nal_sc_len:nal_end] - - if len(nalu) > 0: - nalu_parts.append(len(nalu).to_bytes(4, 'big') + nalu) - - if nalu_parts: - frame_msg = b''.join(nalu_parts) - - for viewer in viewers: - tracks = viewer.get('tracks', []) - - if tracks: - lib.rtcSendMessage(tracks[0], frame_msg, len(frame_msg)) - - last_boundary = au_boundaries[-1] - h264_au_buffer[stream_path] = buf[nal_starts[last_boundary][0]:] - - def destroy_session(stream_path : str) -> None: session = sessions.pop(stream_path, None) - if not session: - return - for viewer in session.get('viewers', []): pc_id = viewer.get('pc') @@ -535,26 +316,8 @@ def destroy_session(stream_path : str) -> None: lib.rtcDeletePeerConnection(pc_id) -def send_data(stream_path : str, data : bytes) -> None: - session = sessions.get(stream_path) - - if not session: - return - - for viewer in session.get('viewers', []): - for track_id in viewer.get('tracks', []): - lib.rtcSendMessage(track_id, data, len(data)) - - def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]: session = sessions.get(stream_path) - - if not session: - return None - - if not lib: - return None - pc = create_peer_connection() gathering_done = threading.Event() local_sdp_holder = [None] @@ -642,73 +405,9 @@ def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]: def start() -> None: - global running, http_thread - - if running: - return - - if not load_library(): - return - - running = True - http_thread = threading.Thread(target = run_http_server, daemon = True) - http_thread.start() - logger.info('rtc whep server started on port ' + str(WHEP_PORT), __name__) + load_library() def stop() -> None: - global running - - running = False - for stream_path in list(sessions.keys()): destroy_session(stream_path) - - -def run_http_server() -> None: - class WhepHandler(BaseHTTPRequestHandler): - def log_message(self, format, *args) -> None: - pass - - def send_cors_headers(self) -> None: - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'Content-Type') - - def do_OPTIONS(self) -> None: - self.send_response(200) - self.send_cors_headers() - self.end_headers() - - def do_POST(self) -> None: - path = self.path - - if not path.endswith('/whep'): - self.send_response(404) - self.send_cors_headers() - self.end_headers() - return - - stream_path = path[1:].rsplit('/whep', 1)[0] - content_length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(content_length).decode('utf-8') if content_length else '' - answer = handle_whep_offer(stream_path, body) - - if answer: - self.send_response(201) - self.send_header('Content-Type', 'application/sdp') - self.send_header('Location', path) - self.send_cors_headers() - self.end_headers() - self.wfile.write(answer.encode('utf-8')) - return - - self.send_response(404) - self.send_cors_headers() - self.end_headers() - - server = HTTPServer(('0.0.0.0', WHEP_PORT), WhepHandler) - server.timeout = 1 - - while running: - server.handle_request() diff --git a/facefusion/webrtc_sfu.py b/facefusion/webrtc_sfu.py deleted file mode 100644 index 55e2f057..00000000 --- a/facefusion/webrtc_sfu.py +++ /dev/null @@ -1,546 +0,0 @@ -import binascii -import hashlib -import os -import socket -import struct -import threading -import time -from typing import Dict, Optional, Tuple, TypeAlias - -import pylibsrtp -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import ec -from OpenSSL import SSL - -from facefusion import logger - -SrtpSession : TypeAlias = pylibsrtp.Session -SrtpPolicy : TypeAlias = pylibsrtp.Policy - -WHIP_PORT : int = 8890 -ICE_UFRAG_LENGTH : int = 4 -ICE_PWD_LENGTH : int = 22 -RTP_HEADER_SIZE : int = 12 - -SRTP_PROFILES =\ -[ - { - 'name': b'SRTP_AES128_CM_SHA1_80', - 'libsrtp': SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80, - 'key_len': 16, - 'salt_len': 14 - } -] - -sessions : Dict[str, dict] = {} -server_cert = None -server_key = None -server_fingerprint : str = '' -udp_socket : Optional[socket.socket] = None -http_thread : Optional[threading.Thread] = None -udp_thread : Optional[threading.Thread] = None -running : bool = False - - -def generate_credentials() -> None: - global server_cert, server_key, server_fingerprint - - server_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) - name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, binascii.hexlify(os.urandom(16)).decode())]) - import datetime - now = datetime.datetime.now(tz = datetime.timezone.utc) - builder = x509.CertificateBuilder().subject_name(name).issuer_name(name).public_key(server_key.public_key()).serial_number(x509.random_serial_number()).not_valid_before(now - datetime.timedelta(days = 1)).not_valid_after(now + datetime.timedelta(days = 30)) - server_cert = builder.sign(server_key, hashes.SHA256(), default_backend()) - fp = server_cert.fingerprint(hashes.SHA256()).hex().upper() - server_fingerprint = ':'.join(fp[i:i + 2] for i in range(0, len(fp), 2)) - - -def generate_ice_credentials() -> Tuple[str, str]: - ufrag = binascii.hexlify(os.urandom(ICE_UFRAG_LENGTH)).decode() - pwd = binascii.hexlify(os.urandom(ICE_PWD_LENGTH)).decode()[:ICE_PWD_LENGTH] - return ufrag, pwd - - -def parse_sdp_offer(sdp : str) -> dict: - result = {'ice_ufrag': '', 'ice_pwd': '', 'fingerprint': '', 'setup': '', 'media': [], 'candidates': []} - current_media = None - - for line in sdp.splitlines(): - line = line.strip() - - if line.startswith('a=ice-ufrag:'): - result['ice_ufrag'] = line.split(':', 1)[1] - if line.startswith('a=ice-pwd:'): - result['ice_pwd'] = line.split(':', 1)[1] - if line.startswith('a=fingerprint:'): - result['fingerprint'] = line.split(' ', 1)[1] if ' ' in line else '' - if line.startswith('a=setup:'): - result['setup'] = line.split(':', 1)[1] - if line.startswith('a=candidate:'): - result['candidates'].append(line[12:]) - if line.startswith('m='): - parts = line[2:].split() - current_media = {'kind': parts[0], 'port': int(parts[1]), 'profile': parts[2], 'formats': parts[3:], 'codec_lines': [], 'mid': None} - result['media'].append(current_media) - if current_media: - if line.startswith('a=rtpmap:') or line.startswith('a=fmtp:') or line.startswith('a=rtcp-fb:'): - current_media['codec_lines'].append(line) - if line.startswith('a=mid:'): - current_media['mid'] = line.split(':', 1)[1] - if line.startswith('a=extmap:'): - current_media['codec_lines'].append(line) - - return result - - -def build_sdp_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int) -> str: - lines = [] - lines.append('v=0') - lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1') - lines.append('s=-') - lines.append('t=0 0') - - mids = [] - for i, media in enumerate(offer.get('media', [])): - mids.append(str(i)) - - if mids: - lines.append('a=group:BUNDLE ' + ' '.join(mids)) - - lines.append('a=ice-lite') - - for i, media in enumerate(offer.get('media', [])): - kind = media.get('kind') - formats = media.get('formats', []) - profile = media.get('profile', 'UDP/TLS/RTP/SAVPF') - mid = media.get('mid', str(i)) - lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats)) - lines.append('c=IN IP4 127.0.0.1') - lines.append('a=rtcp:9 IN IP4 0.0.0.0') - lines.append('a=ice-ufrag:' + local_ufrag) - lines.append('a=ice-pwd:' + local_pwd) - lines.append('a=ice-options:ice2') - lines.append('a=fingerprint:sha-256 ' + server_fingerprint) - lines.append('a=setup:passive') - lines.append('a=mid:' + mid) - lines.append('a=rtcp-mux') - lines.append('a=recvonly') - - for codec_line in media.get('codec_lines', []): - lines.append(codec_line) - - lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host') - lines.append('a=end-of-candidates') - - return '\r\n'.join(lines) + '\r\n' - - -def build_whep_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int, ingest_offer : dict) -> str: - lines = [] - lines.append('v=0') - lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1') - lines.append('s=-') - lines.append('t=0 0') - - mids = [] - for i, media in enumerate(offer.get('media', [])): - mid = media.get('mid', str(i)) - mids.append(mid) - - if mids: - lines.append('a=group:BUNDLE ' + ' '.join(mids)) - - lines.append('a=ice-lite') - - for i, media in enumerate(offer.get('media', [])): - kind = media.get('kind') - formats = media.get('formats', []) - profile = media.get('profile', 'UDP/TLS/RTP/SAVPF') - mid = media.get('mid', str(i)) - lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats)) - lines.append('c=IN IP4 127.0.0.1') - lines.append('a=rtcp:9 IN IP4 0.0.0.0') - lines.append('a=ice-ufrag:' + local_ufrag) - lines.append('a=ice-pwd:' + local_pwd) - lines.append('a=ice-options:ice2') - lines.append('a=fingerprint:sha-256 ' + server_fingerprint) - lines.append('a=setup:passive') - lines.append('a=mid:' + mid) - lines.append('a=rtcp-mux') - lines.append('a=sendonly') - - for codec_line in media.get('codec_lines', []): - lines.append(codec_line) - - lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host') - lines.append('a=end-of-candidates') - - return '\r\n'.join(lines) + '\r\n' - - -def create_ssl_context() -> SSL.Context: - ctx = SSL.Context(SSL.DTLS_METHOD) - ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True) - ctx.use_certificate(server_cert) - ctx.use_privatekey(server_key) - ctx.set_cipher_list(b'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-SHA') - ctx.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80') - return ctx - - -def create_session(stream_path : str) -> None: - ufrag, pwd = generate_ice_credentials() - sessions[stream_path] = { - 'ice_ufrag': ufrag, - 'ice_pwd': pwd, - 'ingest': None, - 'viewers': [], - 'ingest_offer': None, - 'rx_srtp': None, - 'tx_sessions': [] - } - - -def destroy_session(stream_path : str) -> None: - sessions.pop(stream_path, None) - - -def handle_whip(stream_path : str, sdp_offer : str) -> Optional[str]: - session = sessions.get(stream_path) - - if not session: - return None - - offer = parse_sdp_offer(sdp_offer) - session['ingest_offer'] = offer - local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT - answer = build_sdp_answer(offer, session.get('ice_ufrag'), session.get('ice_pwd'), local_port) - return answer - - -def handle_whep(stream_path : str, sdp_offer : str) -> Optional[str]: - session = sessions.get(stream_path) - - if not session: - return None - - offer = parse_sdp_offer(sdp_offer) - viewer_ufrag, viewer_pwd = generate_ice_credentials() - local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT - ingest_offer = session.get('ingest_offer', offer) - answer = build_whep_answer(offer, viewer_ufrag, viewer_pwd, local_port, ingest_offer) - session['viewers'].append({'offer': offer, 'ice_ufrag': viewer_ufrag, 'ice_pwd': viewer_pwd}) - return answer - - -def start() -> None: - global running, udp_socket, http_thread - - generate_credentials() - running = True - - udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - udp_socket.bind(('0.0.0.0', WHIP_PORT)) - udp_socket.settimeout(1.0) - - udp_thread_instance = threading.Thread(target = run_udp_loop, daemon = True) - udp_thread_instance.start() - - http_thread = threading.Thread(target = run_http_server, daemon = True) - http_thread.start() - logger.info('webrtc sfu started on port ' + str(WHIP_PORT), __name__) - - -def stop() -> None: - global running, udp_socket - - running = False - - if udp_socket: - udp_socket.close() - udp_socket = None - - -dtls_connections : Dict[tuple, dict] = {} - - -def run_udp_loop() -> None: - while running: - try: - data, addr = udp_socket.recvfrom(2048) - - if not data: - continue - - first_byte = data[0] - - if first_byte == 0 or first_byte == 1: - handle_stun(data, addr) - if first_byte > 19 and first_byte < 64: - handle_dtls(data, addr) - if first_byte > 127 and first_byte < 192: - handle_srtp(data, addr) - - except socket.timeout: - continue - except Exception: - if running: - continue - - -def handle_dtls(data : bytes, addr : tuple) -> None: - conn = dtls_connections.get(addr) - - if not conn: - ctx = create_ssl_context() - ssl_conn = SSL.Connection(ctx) - ssl_conn.set_accept_state() - conn = {'ssl': ssl_conn, 'encrypted': False, 'rx_srtp': None, 'tx_srtp': None} - dtls_connections[addr] = conn - - ssl_conn = conn.get('ssl') - ssl_conn.bio_write(data) - - try: - if not conn.get('encrypted'): - try: - ssl_conn.do_handshake() - conn['encrypted'] = True - setup_srtp_session(conn) - logger.info('dtls handshake complete with ' + str(addr), __name__) - except SSL.WantReadError: - pass - else: - try: - ssl_conn.recv(1500) - except SSL.ZeroReturnError: - pass - except SSL.Error: - pass - except Exception: - pass - - flush_dtls(ssl_conn, addr) - - -def flush_dtls(ssl_conn : SSL.Connection, addr : tuple) -> None: - try: - outdata = ssl_conn.bio_read(1500) - - if outdata: - udp_socket.sendto(outdata, addr) - except SSL.Error: - pass - - -def setup_srtp_session(conn : dict) -> None: - ssl_conn = conn.get('ssl') - ssl_conn.get_selected_srtp_profile() - key_len = 16 - salt_len = 14 - view = ssl_conn.export_keying_material(b'EXTRACTOR-dtls_srtp', 2 * (key_len + salt_len)) - server_key = view[key_len:2 * key_len] + view[2 * key_len + salt_len:] - client_key = view[:key_len] + view[2 * key_len:2 * key_len + salt_len] - - rx_policy = SrtpPolicy(key = client_key, ssrc_type = SrtpPolicy.SSRC_ANY_INBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80) - rx_policy.allow_repeat_tx = True - rx_policy.window_size = 1024 - conn['rx_srtp'] = SrtpSession(rx_policy) - - tx_policy = SrtpPolicy(key = server_key, ssrc_type = SrtpPolicy.SSRC_ANY_OUTBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80) - tx_policy.allow_repeat_tx = True - tx_policy.window_size = 1024 - conn['tx_srtp'] = SrtpSession(tx_policy) - - -def is_rtcp(data : bytes) -> bool: - if len(data) < 2: - return False - pt = data[1] & 0x7F - return 64 <= pt <= 95 - - -def handle_srtp(data : bytes, addr : tuple) -> None: - conn = dtls_connections.get(addr) - - if not conn or not conn.get('rx_srtp'): - return - - try: - if is_rtcp(data): - plain = conn.get('rx_srtp').unprotect_rtcp(data) - else: - plain = conn.get('rx_srtp').unprotect(data) - - forward_rtp(plain, addr) - except Exception: - pass - - -def forward_rtp(data : bytes, source_addr : tuple) -> None: - for other_addr, conn in dtls_connections.items(): - if other_addr == source_addr: - continue - - if not conn.get('tx_srtp'): - continue - - try: - if is_rtcp(data): - encrypted = conn.get('tx_srtp').protect_rtcp(data) - else: - encrypted = conn.get('tx_srtp').protect(data) - udp_socket.sendto(encrypted, other_addr) - except Exception: - pass - - -def handle_stun(data : bytes, addr : tuple) -> None: - if len(data) < 20: - return - - msg_type = struct.unpack('!H', data[0:2])[0] - - if msg_type != 0x0001: - return - - msg_length = struct.unpack('!H', data[2:4])[0] - transaction_id = data[8:20] - - username = None - offset = 20 - - while offset < 20 + msg_length: - if offset + 4 > len(data): - break - attr_type = struct.unpack('!H', data[offset:offset + 2])[0] - attr_length = struct.unpack('!H', data[offset + 2:offset + 4])[0] - attr_value = data[offset + 4:offset + 4 + attr_length] - - if attr_type == 0x0006: - username = attr_value.decode('utf-8', errors = 'ignore') - - padded = attr_length + (4 - attr_length % 4) % 4 - offset += 4 + padded - - if not username: - return - - local_ufrag = username.split(':')[0] if ':' in username else username - session_pwd = None - - for session in sessions.values(): - if session.get('ice_ufrag') == local_ufrag: - session_pwd = session.get('ice_pwd') - break - - for viewer in session.get('viewers', []): - if viewer.get('ice_ufrag') == local_ufrag: - session_pwd = viewer.get('ice_pwd') - break - - if session_pwd: - break - - if not session_pwd: - return - - response = build_stun_response(transaction_id, addr, session_pwd) - udp_socket.sendto(response, addr) - - -def build_stun_response(transaction_id : bytes, addr : tuple, password : str) -> bytes: - import hmac - import zlib - - magic_cookie = 0x2112A442 - magic_bytes = struct.pack('!I', magic_cookie) - - xport = addr[1] ^ (magic_cookie >> 16) - ip_int = struct.unpack('!I', socket.inet_aton(addr[0]))[0] - xip = struct.pack('!I', ip_int ^ magic_cookie) - xor_addr_value = struct.pack('!BBH', 0, 0x01, xport) + xip - xor_addr_attr = struct.pack('!HH', 0x0020, len(xor_addr_value)) + xor_addr_value - - attrs_before_integrity = xor_addr_attr - integrity_dummy_len = len(attrs_before_integrity) + 4 + 20 - header_for_hmac = struct.pack('!HH', 0x0101, integrity_dummy_len) + magic_bytes + transaction_id - key = password.encode('utf-8') - integrity = hmac.new(key, header_for_hmac + attrs_before_integrity, hashlib.sha1).digest() - integrity_attr = struct.pack('!HH', 0x0008, 20) + integrity - - attrs_before_fp = attrs_before_integrity + integrity_attr - fp_dummy_len = len(attrs_before_fp) + 4 + 4 - header_for_fp = struct.pack('!HH', 0x0101, fp_dummy_len) + magic_bytes + transaction_id - crc = zlib.crc32(header_for_fp + attrs_before_fp) ^ 0x5354554E - fingerprint_attr = struct.pack('!HHI', 0x8028, 4, crc & 0xFFFFFFFF) - - all_attrs = attrs_before_integrity + integrity_attr + fingerprint_attr - header = struct.pack('!HH', 0x0101, len(all_attrs)) + magic_bytes + transaction_id - return header + all_attrs - - -def run_http_server() -> None: - from http.server import HTTPServer, BaseHTTPRequestHandler - - class WhipWhepHandler(BaseHTTPRequestHandler): - def log_message(self, format, *args) -> None: - pass - - def do_POST(self) -> None: - path = self.path - content_length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(content_length).decode('utf-8') if content_length else '' - - if path.endswith('/whip'): - stream_path = path[1:].rsplit('/whip', 1)[0] - answer = handle_whip(stream_path, body) - - if answer: - self.send_response(201) - self.send_header('Content-Type', 'application/sdp') - self.send_header('Location', path) - self.send_header('Access-Control-Allow-Origin', '*') - self.end_headers() - self.wfile.write(answer.encode('utf-8')) - return - - self.send_response(404) - self.end_headers() - return - - if path.endswith('/whep'): - stream_path = path[1:].rsplit('/whep', 1)[0] - answer = handle_whep(stream_path, body) - - if answer: - self.send_response(201) - self.send_header('Content-Type', 'application/sdp') - self.send_header('Location', path) - self.send_header('Access-Control-Allow-Origin', '*') - self.end_headers() - self.wfile.write(answer.encode('utf-8')) - return - - self.send_response(404) - self.end_headers() - return - - self.send_response(404) - self.end_headers() - - def do_OPTIONS(self) -> None: - self.send_response(200) - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'Content-Type') - self.end_headers() - - server = HTTPServer(('0.0.0.0', WHIP_PORT), WhipWhepHandler) - server.timeout = 1 - - while running: - server.handle_request() diff --git a/facefusion/whip_relay.py b/facefusion/whip_relay.py deleted file mode 100644 index d3d4ce95..00000000 --- a/facefusion/whip_relay.py +++ /dev/null @@ -1,62 +0,0 @@ -import threading -from typing import Optional - -from facefusion import logger - -RELAY_PORT : int = 8891 -_started : bool = False -_lock : threading.Lock = threading.Lock() - - -def get_whip_url(stream_path : str) -> str: - from facefusion import rtc - return 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whip' - - -def get_whep_url(stream_path : str) -> str: - from facefusion import rtc - return 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep' - - -def start() -> None: - global _started - - from facefusion import rtc - - if not rtc.lib: - if not rtc.load_library(): - logger.warn('whip relay: libdatachannel not available', __name__) - return - - if not rtc.running: - rtc.start() - - _started = True - logger.info('whip relay (python) ready on port ' + str(rtc.WHEP_PORT), __name__) - - -def stop() -> None: - global _started - _started = False - - -def wait_for_ready() -> bool: - return _started - - -def is_session_ready(stream_path : str) -> bool: - from facefusion import rtc - return stream_path in rtc.sessions - - -def create_session(stream_path : str) -> int: - from facefusion import rtc - - if not _started: - start() - - if not rtc.lib: - return 0 - - rtp_port = rtc.create_rtp_session(stream_path) - return rtp_port diff --git a/mediamtx.yml b/mediamtx.yml deleted file mode 100644 index 51ea060c..00000000 --- a/mediamtx.yml +++ /dev/null @@ -1,9 +0,0 @@ -rtsp: no -rtmp: no -hls: no -srt: no -webrtc: yes -webrtcAddress: :8889 -api: yes -apiAddress: :9997 -paths: diff --git a/test_stream.html b/test_stream.html index ad441f43..20e2067a 100644 --- a/test_stream.html +++ b/test_stream.html @@ -139,24 +139,7 @@
-
4 Streaming Mode
-
- -
-
- -
-
5 Options
+
4 Options