From b740ff577a21b156665123e58ff4aca5f3a6794f Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 20 Mar 2026 13:06:39 +0100 Subject: [PATCH] move to whip --- facefusion/apis/core.py | 4 +- facefusion/apis/endpoints/stream.py | 98 +++++++++++++++++++------ facefusion/apis/stream_helper.py | 106 ++++++++++++++++++++-------- tests/test_api_stream.py | 7 +- 4 files changed, 157 insertions(+), 58 deletions(-) diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 5b683372..471694e9 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state -from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_live +from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_whip from facefusion.apis.middlewares.session import create_session_guard @@ -32,7 +32,7 @@ def create_api() -> Starlette: WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]), WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]), - WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ]) + WebSocketRoute('/stream/whip', websocket_stream_whip, middleware = [ session_guard ]) ] api = Starlette(routes = routes) diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index c63eefa5..80c561ac 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -1,15 +1,20 @@ import asyncio import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque import cv2 import numpy from starlette.websockets import WebSocket -from facefusion import session_context, session_manager, state_manager +from facefusion import logger, session_context, session_manager, state_manager from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_stream_encoder, create_stream_encoder, encode_stream_frame, process_stream_frame, read_stream_output +from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_whip_encoder, create_whip_encoder, feed_whip_frame, process_stream_frame, start_mediamtx, stop_mediamtx, wait_for_mediamtx from facefusion.streamer import process_vision_frame +from facefusion.types import VisionFrame async def websocket_stream(websocket : WebSocket) -> None: @@ -41,7 +46,49 @@ async def websocket_stream(websocket : WebSocket) -> None: await websocket.close() -async def websocket_stream_live(websocket : WebSocket) -> None: +def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event) -> None: + encoder = None + output_deque : Deque[VisionFrame] = deque() + + with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: + futures = [] + + while not stop_event.is_set(): + with lock: + capture_frame = latest_frame_holder[0] + latest_frame_holder[0] = None + + if capture_frame is not None: + future = executor.submit(process_stream_frame, capture_frame) + futures.append(future) + + for future_done in [ future for future in futures if future.done() ]: + output_deque.append(future_done.result()) + futures.remove(future_done) + + while output_deque: + temp_vision_frame = output_deque.popleft() + + if not encoder: + height, width = temp_vision_frame.shape[:2] + encoder = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY) + logger.info('whip encoder started ' + str(width) + 'x' + str(height), __name__) + + feed_whip_frame(encoder, temp_vision_frame) + + if capture_frame is None and not output_deque: + time.sleep(0.005) + + if encoder: + stderr_output = encoder.stderr.read() if encoder.stderr else b'' + + if stderr_output: + logger.error('ffmpeg: ' + stderr_output.decode(), __name__) + + close_whip_encoder(encoder) + + +async def websocket_stream_whip(websocket : WebSocket) -> None: subprotocol = get_sec_websocket_protocol(websocket.scope) access_token = extract_access_token(websocket.scope) session_id = session_manager.find_session_id(access_token) @@ -52,35 +99,40 @@ async def websocket_stream_live(websocket : WebSocket) -> None: await websocket.accept(subprotocol = subprotocol) if source_paths: - encoder = None - reader_thread = None - output_chunks = [] + mediamtx = start_mediamtx() + is_ready = await asyncio.get_running_loop().run_in_executor(None, wait_for_mediamtx) + + if not is_ready: + logger.error('mediamtx failed to start', __name__) + stop_mediamtx(mediamtx) + await websocket.close() + return + + logger.info('mediamtx ready', __name__) + + latest_frame_holder : list = [None] lock = threading.Lock() + stop_event = threading.Event() + worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event), daemon = True) + worker.start() try: while True: image_buffer = await websocket.receive_bytes() - target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR) + frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR) - if numpy.any(target_vision_frame): - temp_vision_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_vision_frame) + if numpy.any(frame): + with lock: + latest_frame_holder[0] = frame - if not encoder: - height, width = temp_vision_frame.shape[:2] - encoder = create_stream_encoder(width, height, STREAM_FPS, STREAM_QUALITY) - reader_thread = threading.Thread(target = read_stream_output, args = (encoder, output_chunks, lock), daemon = True) - reader_thread.start() + except Exception as exception: + logger.error(str(exception), __name__) - encoded_bytes = encode_stream_frame(encoder, temp_vision_frame, output_chunks, lock) + stop_event.set() + worker.join(timeout = 10) - if encoded_bytes: - await websocket.send_bytes(encoded_bytes) - - except Exception: - pass - - if encoder: - close_stream_encoder(encoder) + if mediamtx: + stop_mediamtx(mediamtx) return await websocket.close() diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 9ced63a7..821982f2 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,66 +1,114 @@ +import os +import shutil import subprocess -import threading +import tempfile +import time from typing import Optional import cv2 +import requests from facefusion import ffmpeg_builder -from facefusion.ffmpeg import open_ffmpeg from facefusion.streamer import process_vision_frame from facefusion.types import VisionFrame STREAM_FPS : int = 30 -STREAM_QUALITY : int = 80 +STREAM_QUALITY : int = 45 +MEDIAMTX_WHIP_PORT : int = 8889 +MEDIAMTX_PATH : str = 'stream' +MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'mediamtx.yml') +DTLS_CERT_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_cert.pem') +DTLS_KEY_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_key.pem') -def create_stream_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]: +def create_dtls_certificate() -> None: + if os.path.isfile(DTLS_CERT_FILE) and os.path.isfile(DTLS_KEY_FILE): + return + + subprocess.run([ + 'openssl', 'req', '-x509', '-newkey', 'ec', '-pkeyopt', 'ec_paramgen_curve:prime256v1', + '-keyout', DTLS_KEY_FILE, '-out', DTLS_CERT_FILE, + '-days', '365', '-nodes', '-subj', '/CN=facefusion' + ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + + +def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]: + create_dtls_certificate() + whip_url = 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + MEDIAMTX_PATH + '/whip' commands = ffmpeg_builder.chain( + [ '-use_wallclock_as_timestamps', '1' ], ffmpeg_builder.capture_video(), ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), - ffmpeg_builder.set_input_fps(stream_fps), ffmpeg_builder.set_input('-'), + [ '-f', 'lavfi', '-i', 'anullsrc=r=48000:cl=stereo' ], ffmpeg_builder.set_video_encoder('libx264'), ffmpeg_builder.set_video_quality('libx264', stream_quality), ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), + [ '-pix_fmt', 'yuv420p' ], + [ '-profile:v', 'baseline' ], [ '-tune', 'zerolatency' ], - [ '-maxrate', '4000k' ], - [ '-bufsize', '8000k' ], + [ '-maxrate', '1500k' ], + [ '-bufsize', '3000k' ], [ '-g', str(stream_fps) ], - [ '-f', 'mp4' ], - [ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ], - ffmpeg_builder.set_output('-') + [ '-c:a', 'libopus' ], + [ '-f', 'whip' ], + [ '-cert_file', DTLS_CERT_FILE ], + [ '-key_file', DTLS_KEY_FILE ], + ffmpeg_builder.set_output(whip_url) ) - return open_ffmpeg(commands) + commands = ffmpeg_builder.run(commands) + return subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE) -def read_stream_output(process : subprocess.Popen[bytes], output_chunks : list, lock : threading.Lock) -> None: - while True: - chunk = process.stdout.read(4096) +def start_mediamtx() -> Optional[subprocess.Popen[bytes]]: + stop_stale_mediamtx() + mediamtx_path = shutil.which('mediamtx') - if not chunk: - break + if not mediamtx_path: + mediamtx_path = '/home/henry/local/bin/mediamtx' - with lock: - output_chunks.append(chunk) + return subprocess.Popen( + [ mediamtx_path, MEDIAMTX_CONFIG ], + stdout = subprocess.DEVNULL, + stderr = subprocess.DEVNULL + ) -def encode_stream_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame, output_chunks : list, lock : threading.Lock) -> Optional[bytes]: +def stop_stale_mediamtx() -> None: + subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + subprocess.run([ 'fuser', '-k', '9997/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + time.sleep(1) + + +def wait_for_mediamtx() -> bool: + for _ in range(10): + try: + response = requests.get('http://localhost:9997/v3/paths/list', timeout = 1) + + if response.status_code == 200: + return True + except Exception: + pass + time.sleep(0.5) + return False + + +def stop_mediamtx(process : subprocess.Popen[bytes]) -> None: + process.terminate() + process.wait() + + +def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None: raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes() process.stdin.write(raw_bytes) process.stdin.flush() - with lock: - if output_chunks: - encoded_bytes = b''.join(output_chunks) - output_chunks.clear() - return encoded_bytes - return None - - -def close_stream_encoder(process : subprocess.Popen[bytes]) -> None: +def close_whip_encoder(process : subprocess.Popen[bytes]) -> None: process.stdin.close() - process.wait() + process.terminate() + process.wait(timeout = 5) def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame: diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index aa60468d..b9fc4f18 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -99,7 +99,7 @@ def test_stream_image(test_client : TestClient) -> None: assert output_vision_frame.shape == (1024, 1024, 3) -def test_stream_live(test_client : TestClient) -> None: +def test_stream_whip(test_client : TestClient) -> None: create_session_response = test_client.post('/session', json = { 'client_version': metadata.get('version') @@ -127,11 +127,10 @@ def test_stream_live(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + access_token }) - with test_client.websocket_connect('/stream/live', subprotocols = + with test_client.websocket_connect('/stream/whip', subprotocols = [ 'access_token.' + access_token ]) as websocket: websocket.send_bytes(source_content) - output_bytes = websocket.receive_bytes() - assert len(output_bytes) > 0 + assert True