From 6047463154b91f0265f28f04e8be4f8114a0cef4 Mon Sep 17 00:00:00 2001 From: Harisreedhar <46858047+harisreedhar@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:26:10 +0530 Subject: [PATCH] Replace aiortc with libdatachannel direct pipeline (#1083) * fix stdin close error * Refactor stream endpoint, fix encoder thread safety and improve tests * fix and improve test * remove not None * use Enum * use Enum and add todo * remove poll --- facefusion/apis/core.py | 6 +- facefusion/apis/endpoints/stream.py | 73 +++----------- facefusion/apis/stream_helper.py | 148 +++++++++++++++++++--------- facefusion/rtc_store.py | 28 +++--- facefusion/types.py | 1 - requirements.txt | 1 - tests/stream_helper.py | 65 +++++++++--- tests/test_api_stream.py | 28 ++++-- tests/test_stream_helper.py | 28 +----- 9 files changed, 208 insertions(+), 170 deletions(-) diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 5be40df6..effd7341 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state -from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream +from facefusion.apis.endpoints.stream import post_stream, websocket_stream from facefusion.apis.middlewares.session import create_session_guard @@ -29,10 +29,10 @@ def create_api() -> Starlette: Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]), Route('/capabilities', get_capabilities, methods = [ 'GET' ]), Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]), - Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]), + Route('/stream', post_stream, methods = [ 'POST' ], middleware = [ session_guard ]), WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]), - WebSocketRoute('/stream', websocket_stream, middleware = [session_guard]) + WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]) ] api = Starlette(routes = routes) diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index f2789520..3191113c 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -1,76 +1,33 @@ -from functools import partial - -import cv2 -import numpy -from aiortc import RTCPeerConnection, RTCSessionDescription from starlette.requests import Request -from starlette.responses import JSONResponse, Response -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from starlette.responses import Response +from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND from starlette.websockets import WebSocket -from facefusion import session_context, session_manager, state_manager -from facefusion.apis.api_helper import get_sec_websocket_protocol +from facefusion import rtc_store, session_context, session_manager from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import create_output_track, on_video_track -from facefusion.streamer import process_vision_frame +from facefusion.apis.stream_helper import get_websocket_stream_mode, handle_image_stream, handle_video_stream async def websocket_stream(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) + stream_mode = get_websocket_stream_mode(websocket.scope) - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') + if stream_mode == 'image': + await handle_image_stream(websocket) - await websocket.accept(subprotocol = subprotocol) - - if source_paths: - try: - image_buffer = await websocket.receive_bytes() - target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(target_vision_frame): - temp_vision_frame = process_vision_frame(target_vision_frame) - is_success, output_vision_frame = cv2.imencode('.jpg', temp_vision_frame) - - if is_success: - await websocket.send_bytes(output_vision_frame.tobytes()) - - except Exception: - pass - return - - await websocket.close() + if stream_mode == 'video': + await handle_video_stream(websocket) -async def webrtc_stream(request : Request) -> Response: +async def post_stream(request : Request) -> Response: access_token = extract_access_token(request.scope) session_id = session_manager.find_session_id(access_token) session_context.set_session_id(session_id) if session_id: - body = await request.json() - buffer_size = int(body.get('buffer_size', 30)) - bitrate_init = int(body.get('bitrate_init', 100000)) - bitrate_min = int(body.get('bitrate_min', 100000)) - bitrate_max = int(body.get('bitrate_max', 4000000)) + sdp_offer = (await request.body()).decode() + sdp_answer = rtc_store.add_rtc_viewer(session_id, sdp_offer) - rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type')) - rtc_connection = RTCPeerConnection() + if sdp_answer: + return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp') - output_track, sender = create_output_track(rtc_connection, buffer_size) - sender.configure_bitrate(bitrate_init, bitrate_min, bitrate_max) - - rtc_connection.on('track', partial(on_video_track, rtc_connection, output_track)) - - await rtc_connection.setRemoteDescription(rtc_offer) - await rtc_connection.setLocalDescription(await rtc_connection.createAnswer()) - - return JSONResponse( - { - 'sdp': rtc_connection.localDescription.sdp, - 'type': rtc_connection.localDescription.type - }) - - return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR) + return Response(status_code = HTTP_404_NOT_FOUND) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index e235084d..b1f12905 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -2,53 +2,23 @@ import asyncio import math import os import subprocess -from typing import Iterator, Optional, Tuple, cast +from collections import deque +from collections.abc import AsyncIterator +from typing import Optional, cast -from aiortc import MediaStreamTrack, QueuedVideoStreamTrack, RTCPeerConnection, RTCRtpSender -from aiortc.mediastreams import MediaStreamError -from av import VideoFrame +import cv2 +import numpy from starlette.datastructures import Headers from starlette.types import Scope +from starlette.websockets import WebSocket, WebSocketState +from facefusion import rtc_store, session_context, session_manager, state_manager +from facefusion.apis.api_helper import get_sec_websocket_protocol +from facefusion.apis.session_helper import extract_access_token from facefusion.common_helper import is_linux, is_macos +from facefusion.ffmpeg import spawn_stream from facefusion.streamer import process_vision_frame -from facefusion.types import Resolution, StreamBuffer, WebSocketStreamMode - - -def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame: - target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24') - output_vision_frame = process_vision_frame(target_vision_frame) - output_stream_frame = VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24') - output_stream_frame.pts = target_stream_frame.pts - output_stream_frame.time_base = target_stream_frame.time_base - return output_stream_frame - - -def create_output_track(rtc_connection : RTCPeerConnection, buffer_size : int) -> Tuple[QueuedVideoStreamTrack, RTCRtpSender]: - output_track = QueuedVideoStreamTrack(buffer_size = buffer_size) - sender = rtc_connection.addTrack(output_track) - return output_track, sender - - -async def process_and_enqueue(target_track : MediaStreamTrack, output_track : QueuedVideoStreamTrack) -> None: - loop = asyncio.get_running_loop() - - while True: - try: - target_stream_frame = await target_track.recv() - except MediaStreamError: - pass - - output_stream_frame = await loop.run_in_executor(None, process_stream_frame, target_stream_frame) #type:ignore[arg-type] - await output_track.put(output_stream_frame) - - -def on_video_track(rtc_connection : RTCPeerConnection, output_track : QueuedVideoStreamTrack, target_track : MediaStreamTrack) -> None: - if target_track.kind == 'audio': - rtc_connection.addTrack(target_track) - - if target_track.kind == 'video': - asyncio.create_task(process_and_enqueue(target_track, output_track)) +from facefusion.types import Resolution, SessionId, VisionFrame, WebSocketStreamMode def calculate_bitrate(resolution : Resolution) -> int: # TODO : improve the bitrate calculation @@ -89,8 +59,21 @@ def read_pipe_buffer(pipe_handle : int, size : int) -> Optional[bytes]: return None -def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBuffer]: - pipe_handle = process.stdout.fileno() +async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]: + websocket_event = await websocket.receive() + + while websocket_event.get('type') == 'websocket.receive': + frame_buffer = websocket_event.get('bytes') or b'' + vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) + + if numpy.any(vision_frame): + yield vision_frame + + websocket_event = await websocket.receive() + + +def forward_rtc_frames(encoder : subprocess.Popen[bytes], session_id : SessionId) -> None: + pipe_handle = encoder.stdout.fileno() if is_linux() or is_macos(): os.set_blocking(pipe_handle, True) @@ -105,6 +88,83 @@ def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBu frame_data = read_pipe_buffer(pipe_handle, frame_size) if frame_data: - yield frame_data + rtc_store.send_rtc_frame(session_id, frame_data) frame_header = read_pipe_buffer(pipe_handle, 12) + + +def submit_encoder_frame(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None: + output_vision_frame = process_vision_frame(vision_frame_deque[-1]) + encoder.stdin.write(cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2RGB).tobytes()) + encoder.stdin.flush() + + +def run_encode_loop(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None: + while vision_frame_deque: + submit_encoder_frame(encoder, vision_frame_deque) + + encoder.stdin.close() + encoder.wait() + + +async def handle_image_stream(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) + session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) + source_paths = state_manager.get_item('source_paths') + + await websocket.accept(subprotocol = subprotocol) + + if source_paths: + capture_vision_frame = await anext(receive_vision_frames(websocket), None) + + if numpy.any(capture_vision_frame): + output_vision_frame = process_vision_frame(capture_vision_frame) + is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame) + + if is_success: + await websocket.send_bytes(output_frame_buffer.tobytes()) + + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close() + + +async def handle_video_stream(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) + session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) + source_paths = state_manager.get_item('source_paths') + + await websocket.accept(subprotocol = subprotocol) + + if session_id and source_paths: + output_video_fps = int(state_manager.get_item('output_video_fps') or 30) # TODO: resolve from target video fps + vision_frames = receive_vision_frames(websocket) + vision_frame = await anext(vision_frames, None) + + if numpy.any(vision_frame): + resolution = (vision_frame.shape[1], vision_frame.shape[0]) + encoder = spawn_stream(resolution, output_video_fps, calculate_bitrate(resolution), calculate_buffer_size(resolution)) + + vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1) + + vision_frame_deque.append(vision_frame) + rtc_store.create_rtc_stream(session_id) + + event_loop = asyncio.get_running_loop() + await event_loop.run_in_executor(None, submit_encoder_frame, encoder, vision_frame_deque) + await websocket.send_text('ready') + encode_task = event_loop.run_in_executor(None, run_encode_loop, encoder, vision_frame_deque) + rtc_task = event_loop.run_in_executor(None, forward_rtc_frames, encoder, session_id) + + async for vision_frame in vision_frames: + vision_frame_deque.append(vision_frame) + + vision_frame_deque.clear() + await asyncio.gather(encode_task, rtc_task) + rtc_store.destroy_rtc_stream(session_id) + + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close() diff --git a/facefusion/rtc_store.py b/facefusion/rtc_store.py index fad822f7..4be9e271 100644 --- a/facefusion/rtc_store.py +++ b/facefusion/rtc_store.py @@ -1,37 +1,35 @@ from typing import List, Optional from facefusion import rtc -from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore +from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore, SessionId -RTC_STREAMS : RtcStreamStore = {} # TODO: tie lifetime to session_id so streams are cleaned up on session expiry +RTC_STREAMS : RtcStreamStore = {} -def get_rtc_stream(stream_path : str) -> Optional[List[RtcPeer]]: - return RTC_STREAMS.get(stream_path) +def get_rtc_stream(session_id : SessionId) -> Optional[List[RtcPeer]]: + return RTC_STREAMS.get(session_id) -def create_rtc_stream(stream_path : str) -> None: - RTC_STREAMS[stream_path] = [] +def create_rtc_stream(session_id : SessionId) -> None: + RTC_STREAMS[session_id] = [] -def destroy_rtc_stream(stream_path : str) -> None: - peers = RTC_STREAMS.pop(stream_path, None) +def destroy_rtc_stream(session_id : SessionId) -> None: + peers = RTC_STREAMS.pop(session_id, None) if peers: rtc.delete_peers(peers) -def add_rtc_viewer(stream_path : str, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]: - peers = get_rtc_stream(stream_path) - - if peers: - return rtc.handle_whep_offer(peers, sdp_offer) +def add_rtc_viewer(session_id : SessionId, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]: + if session_id in RTC_STREAMS: + return rtc.handle_whep_offer(RTC_STREAMS.get(session_id), sdp_offer) return None -def send_rtc_frame(stream_path : str, frame_data : bytes) -> None: - peers = get_rtc_stream(stream_path) +def send_rtc_frame(session_id : SessionId, frame_data : bytes) -> None: + peers = get_rtc_stream(session_id) if peers: rtc.send_to_peers(peers, frame_data) diff --git a/facefusion/types.py b/facefusion/types.py index f75b02c6..165d1856 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -265,7 +265,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet', WebcamMode = Literal['inline', 'udp', 'v4l2'] StreamMode = Literal['udp', 'v4l2'] WebSocketStreamMode = Literal['image', 'video'] -StreamBuffer : TypeAlias = bytes RtcOfferSet = TypedDict('RtcOfferSet', { diff --git a/requirements.txt b/requirements.txt index afc8942f..826b791b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48 psutil==7.2.2 tqdm==4.67.3 scipy==1.16.3 -aiortc @ git+https://github.com/facefusion/aiortc.git@feat/dynamic-bitrate starlette==0.52.1 uvicorn==0.41.0 websockets==16.0 diff --git a/tests/stream_helper.py b/tests/stream_helper.py index d185481a..b60bb9ff 100644 --- a/tests/stream_helper.py +++ b/tests/stream_helper.py @@ -1,21 +1,58 @@ -from aiortc import RTCPeerConnection, VideoStreamTrack +import ctypes +import os +import threading +import time +from typing import Optional -from facefusion.types import RtcOfferSet +from starlette.testclient import TestClient + +from facefusion import rtc +from facefusion.types import RtcSdpOffer -async def create_rtc_offer() -> RtcOfferSet: - rtc_connection = RTCPeerConnection() - rtc_connection.addTrack(VideoStreamTrack()) - rtc_offer = await rtc_connection.createOffer() +def create_sdp_offer() -> Optional[RtcSdpOffer]: + rtc_library = rtc.create_static_rtc_library() + peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True) - await rtc_connection.setLocalDescription(rtc_offer) + media_video = os.linesep.join( + [ + 'm=video 9 UDP/TLS/RTP/SAVPF 96', + 'a=rtpmap:96 VP8/90000', + 'a=recvonly', + 'a=mid:0', + '' + ]).encode() + media_audio = os.linesep.join( + [ + 'm=audio 9 UDP/TLS/RTP/SAVPF 111', + 'a=rtpmap:111 opus/48000/2', + 'a=recvonly', + 'a=mid:1', + '' + ]).encode() - rtc_offer_set : RtcOfferSet =\ - { - 'sdp': rtc_connection.localDescription.sdp, - 'type': rtc_connection.localDescription.type - } + rtc_library.rtcAddTrack(peer_connection, media_video) + rtc_library.rtcAddTrack(peer_connection, media_audio) + rtc_library.rtcSetLocalDescription(peer_connection, b'offer') - await rtc_connection.close() + buffer_size = 16384 + buffer_string = ctypes.create_string_buffer(buffer_size) + wait_limit = time.monotonic() + 5 - return rtc_offer_set + while time.monotonic() < wait_limit: + if rtc_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0: + sdp = buffer_string.value.decode() + rtc_library.rtcDeletePeerConnection(peer_connection) + return sdp + time.sleep(0.05) + + rtc_library.rtcDeletePeerConnection(peer_connection) + return None + + +def open_websocket_stream(test_client : TestClient, subprotocols : list[str], source_content : bytes, ready_event : threading.Event, stop_event : threading.Event) -> None: + with test_client.websocket_connect('/stream', subprotocols = subprotocols) as websocket: + websocket.send_bytes(source_content) + websocket.receive_text() + ready_event.set() + stop_event.wait() diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index 1f862d88..ad2813e8 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -1,5 +1,5 @@ -import asyncio import tempfile +import threading from typing import Iterator import cv2 @@ -13,7 +13,7 @@ from facefusion.apis.core import create_api from facefusion.core import common_pre_check, processors_pre_check from facefusion.download import conditional_download from .assert_helper import get_test_example_file, get_test_examples_directory -from .stream_helper import create_rtc_offer +from .stream_helper import create_sdp_offer, open_websocket_stream @pytest.fixture(scope = 'module', autouse = True) @@ -92,7 +92,8 @@ def test_stream_image(test_client : TestClient) -> None: with test_client.websocket_connect('/stream', subprotocols = [ - 'access_token.' + access_token + 'access_token.' + access_token, + 'image' ]) as websocket: websocket.send_bytes(source_content) output_bytes = websocket.receive_bytes() @@ -129,12 +130,21 @@ def test_stream_video(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + access_token }) - rtc_offer = asyncio.run(create_rtc_offer()) - stream_response = test_client.post('/stream', json = rtc_offer, headers = + ready_event = threading.Event() + stop_event = threading.Event() + stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token, 'video' ], source_content, ready_event, stop_event)) + stream_thread.start() + ready_event.wait() + + sdp_offer = create_sdp_offer() + stream_response = test_client.post('/stream', content = sdp_offer, headers = { - 'Authorization': 'Bearer ' + access_token + 'Authorization': 'Bearer ' + access_token, + 'Content-Type': 'application/sdp' }) - assert stream_response.status_code == 200 - assert stream_response.json().get('type') == 'answer' - assert stream_response.json().get('sdp') + assert stream_response.status_code == 201 + assert stream_response.text + + stop_event.set() + stream_thread.join() diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py index efb0be28..f7644712 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_stream_helper.py @@ -1,9 +1,6 @@ import os -import subprocess -from facefusion import ffmpeg_builder -from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, forward_stream_frame, get_websocket_stream_mode, read_pipe_buffer -from facefusion.vision import pack_resolution +from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, get_websocket_stream_mode, read_pipe_buffer def make_scope(protocol : str) -> dict[str, object]: @@ -30,7 +27,7 @@ def test_calculate_buffer_size() -> None: assert calculate_buffer_size((3840, 2160)) == 14000 -def test_get_websocket_stream_mode() -> None: +def test_get_stream_mode() -> None: assert get_websocket_stream_mode(make_scope('image')) == 'image' assert get_websocket_stream_mode(make_scope('video')) == 'video' @@ -47,23 +44,4 @@ def test_read_pipe_buffer() -> None: os.close(read_fd) -def test_forward_frames() -> None: - resolution = (320, 240) - frame_size = resolution[0] * resolution[1] * 3 - commands = ffmpeg_builder.run(ffmpeg_builder.chain( - ffmpeg_builder.capture_video(), - ffmpeg_builder.set_media_resolution(pack_resolution(resolution)), - ffmpeg_builder.set_input_fps(30), - ffmpeg_builder.set_input('-'), - ffmpeg_builder.set_video_encoder('libvpx'), - ffmpeg_builder.set_encoder_deadline('realtime'), - ffmpeg_builder.set_stream_quality(400), - ffmpeg_builder.set_muxer('ivf'), - ffmpeg_builder.set_output('-') - )) - encoder = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE) - encoder.stdin.write(bytes(frame_size)) - encoder.stdin.close() - - for stream_buffer in forward_stream_frame(encoder): - assert 0 < len(stream_buffer) < frame_size +# TODO: add remaining tests