From 87c2eebb2d9a667a16f3e7a798b2929bbeee2e86 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 20 Mar 2026 10:49:32 +0100 Subject: [PATCH] experiment to run video via websocket --- facefusion/apis/core.py | 6 +-- facefusion/apis/endpoints/stream.py | 60 ++++++++++++++-------- facefusion/apis/stream_helper.py | 79 ++++++++++++++++++++--------- facefusion/types.py | 6 --- requirements.txt | 1 - tests/stream_helper.py | 26 +++------- tests/test_api_stream.py | 19 +++---- 7 files changed, 113 insertions(+), 84 deletions(-) diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 5be40df6..5b683372 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state -from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream +from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_live from facefusion.apis.middlewares.session import create_session_guard @@ -29,10 +29,10 @@ def create_api() -> Starlette: Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]), Route('/capabilities', get_capabilities, methods = [ 'GET' ]), Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]), - Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]), WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]), - WebSocketRoute('/stream', websocket_stream, middleware = [session_guard]) + WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]), + WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ]) ] api = Starlette(routes = routes) diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 71de97da..c63eefa5 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -1,17 +1,14 @@ -from functools import partial +import asyncio +import threading import cv2 import numpy -from aiortc import RTCPeerConnection, RTCSessionDescription -from starlette.requests import Request -from starlette.responses import JSONResponse, Response -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlette.websockets import WebSocket from facefusion import session_context, session_manager, state_manager from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import on_video_track +from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_stream_encoder, create_stream_encoder, encode_stream_frame, process_stream_frame, read_stream_output from facefusion.streamer import process_vision_frame @@ -44,25 +41,46 @@ async def websocket_stream(websocket : WebSocket) -> None: await websocket.close() -async def webrtc_stream(request : Request) -> Response: - access_token = extract_access_token(request.scope) +async def websocket_stream_live(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) + source_paths = state_manager.get_item('source_paths') - if session_id: - body = await request.json() - rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type')) - rtc_connection = RTCPeerConnection() + await websocket.accept(subprotocol = subprotocol) - rtc_connection.on('track', partial(on_video_track, rtc_connection)) + if source_paths: + encoder = None + reader_thread = None + output_chunks = [] + lock = threading.Lock() - await rtc_connection.setRemoteDescription(rtc_offer) - await rtc_connection.setLocalDescription(await rtc_connection.createAnswer()) + try: + while True: + image_buffer = await websocket.receive_bytes() + target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR) - return JSONResponse( - { - 'sdp': rtc_connection.localDescription.sdp, - 'type': rtc_connection.localDescription.type - }) + if numpy.any(target_vision_frame): + temp_vision_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_vision_frame) - return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR) + if not encoder: + height, width = temp_vision_frame.shape[:2] + encoder = create_stream_encoder(width, height, STREAM_FPS, STREAM_QUALITY) + reader_thread = threading.Thread(target = read_stream_output, args = (encoder, output_chunks, lock), daemon = True) + reader_thread.start() + + encoded_bytes = encode_stream_frame(encoder, temp_vision_frame, output_chunks, lock) + + if encoded_bytes: + await websocket.send_bytes(encoded_bytes) + + except Exception: + pass + + if encoder: + close_stream_encoder(encoder) + return + + await websocket.close() diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 24c2825e..9ced63a7 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,36 +1,67 @@ -import asyncio -from typing import cast +import subprocess +import threading +from typing import Optional -from aiortc import MediaStreamTrack, RTCPeerConnection, VideoStreamTrack -from av import VideoFrame +import cv2 +from facefusion import ffmpeg_builder +from facefusion.ffmpeg import open_ffmpeg from facefusion.streamer import process_vision_frame +from facefusion.types import VisionFrame + +STREAM_FPS : int = 30 +STREAM_QUALITY : int = 80 -def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame: - target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24') - output_vision_frame = process_vision_frame(target_vision_frame) - return VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24') +def create_stream_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]: + commands = ffmpeg_builder.chain( + ffmpeg_builder.capture_video(), + ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)), + ffmpeg_builder.set_input_fps(stream_fps), + ffmpeg_builder.set_input('-'), + ffmpeg_builder.set_video_encoder('libx264'), + ffmpeg_builder.set_video_quality('libx264', stream_quality), + ffmpeg_builder.set_video_preset('libx264', 'ultrafast'), + [ '-tune', 'zerolatency' ], + [ '-maxrate', '4000k' ], + [ '-bufsize', '8000k' ], + [ '-g', str(stream_fps) ], + [ '-f', 'mp4' ], + [ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ], + ffmpeg_builder.set_output('-') + ) + return open_ffmpeg(commands) -def create_output_track(target_track : MediaStreamTrack) -> VideoStreamTrack: - output_track = VideoStreamTrack() +def read_stream_output(process : subprocess.Popen[bytes], output_chunks : list, lock : threading.Lock) -> None: + while True: + chunk = process.stdout.read(4096) - async def read_stream_frame() -> VideoFrame: - target_stream_frame = cast(VideoFrame, await target_track.recv()) - output_stream_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_stream_frame) - output_stream_frame.pts = target_stream_frame.pts - output_stream_frame.time_base = target_stream_frame.time_base - return output_stream_frame + if not chunk: + break - output_track.recv = read_stream_frame - return output_track + with lock: + output_chunks.append(chunk) -def on_video_track(rtc_connection : RTCPeerConnection, target_track : MediaStreamTrack) -> None: - if target_track.kind == 'audio': - rtc_connection.addTrack(target_track) +def encode_stream_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame, output_chunks : list, lock : threading.Lock) -> Optional[bytes]: + raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes() + process.stdin.write(raw_bytes) + process.stdin.flush() - if target_track.kind == 'video': - output_track = create_output_track(target_track) - rtc_connection.addTrack(output_track) + with lock: + if output_chunks: + encoded_bytes = b''.join(output_chunks) + output_chunks.clear() + return encoded_bytes + + return None + + +def close_stream_encoder(process : subprocess.Popen[bytes]) -> None: + process.stdin.close() + process.wait() + + +def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame: + return process_vision_frame(vision_frame) diff --git a/facefusion/types.py b/facefusion/types.py index f9137325..0b7d985a 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -259,12 +259,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet', WebcamMode = Literal['inline', 'udp', 'v4l2'] StreamMode = Literal['udp', 'v4l2'] -RtcOfferSet = TypedDict('RtcOfferSet', -{ - 'sdp': str, - 'type': str -}) - ModelOptions : TypeAlias = Dict[str, Any] ModelSet : TypeAlias = Dict[str, ModelOptions] ModelInitializer : TypeAlias = NDArray[Any] diff --git a/requirements.txt b/requirements.txt index 238feb42..826b791b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48 psutil==7.2.2 tqdm==4.67.3 scipy==1.16.3 -aiortc==1.14.0 starlette==0.52.1 uvicorn==0.41.0 websockets==16.0 diff --git a/tests/stream_helper.py b/tests/stream_helper.py index d185481a..6a16ef36 100644 --- a/tests/stream_helper.py +++ b/tests/stream_helper.py @@ -1,21 +1,11 @@ -from aiortc import RTCPeerConnection, VideoStreamTrack - -from facefusion.types import RtcOfferSet +import cv2 +import numpy -async def create_rtc_offer() -> RtcOfferSet: - rtc_connection = RTCPeerConnection() - rtc_connection.addTrack(VideoStreamTrack()) - rtc_offer = await rtc_connection.createOffer() +def create_test_frame_bytes(width : int, height : int) -> bytes: + vision_frame = numpy.zeros((height, width, 3), dtype = numpy.uint8) + is_success, image_buffer = cv2.imencode('.jpg', vision_frame) - await rtc_connection.setLocalDescription(rtc_offer) - - rtc_offer_set : RtcOfferSet =\ - { - 'sdp': rtc_connection.localDescription.sdp, - 'type': rtc_connection.localDescription.type - } - - await rtc_connection.close() - - return rtc_offer_set + if is_success: + return image_buffer.tobytes() + return b'' diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index 1f862d88..aa60468d 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -1,4 +1,3 @@ -import asyncio import tempfile from typing import Iterator @@ -13,7 +12,6 @@ from facefusion.apis.core import create_api from facefusion.core import common_pre_check, processors_pre_check from facefusion.download import conditional_download from .assert_helper import get_test_example_file, get_test_examples_directory -from .stream_helper import create_rtc_offer @pytest.fixture(scope = 'module', autouse = True) @@ -101,7 +99,7 @@ def test_stream_image(test_client : TestClient) -> None: assert output_vision_frame.shape == (1024, 1024, 3) -def test_stream_video(test_client : TestClient) -> None: +def test_stream_live(test_client : TestClient) -> None: create_session_response = test_client.post('/session', json = { 'client_version': metadata.get('version') @@ -129,12 +127,11 @@ def test_stream_video(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + access_token }) - rtc_offer = asyncio.run(create_rtc_offer()) - stream_response = test_client.post('/stream', json = rtc_offer, headers = - { - 'Authorization': 'Bearer ' + access_token - }) + with test_client.websocket_connect('/stream/live', subprotocols = + [ + 'access_token.' + access_token + ]) as websocket: + websocket.send_bytes(source_content) + output_bytes = websocket.receive_bytes() - assert stream_response.status_code == 200 - assert stream_response.json().get('type') == 'answer' - assert stream_response.json().get('sdp') + assert len(output_bytes) > 0