diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py
index 71de97da..f2789520 100644
--- a/facefusion/apis/endpoints/stream.py
+++ b/facefusion/apis/endpoints/stream.py
@@ -11,7 +11,7 @@ from starlette.websockets import WebSocket
 from facefusion import session_context, session_manager, state_manager
 from facefusion.apis.api_helper import get_sec_websocket_protocol
 from facefusion.apis.session_helper import extract_access_token
-from facefusion.apis.stream_helper import on_video_track
+from facefusion.apis.stream_helper import create_output_track, on_video_track
 from facefusion.streamer import process_vision_frame
 
 
@@ -51,10 +51,18 @@ async def webrtc_stream(request : Request) -> Response:
 
 	if session_id:
 		body = await request.json()
+		buffer_size = int(body.get('buffer_size', 30))
+		bitrate_init = int(body.get('bitrate_init', 100000))
+		bitrate_min = int(body.get('bitrate_min', 100000))
+		bitrate_max = int(body.get('bitrate_max', 4000000))
+
 		rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
 		rtc_connection = RTCPeerConnection()
-		rtc_connection.on('track', partial(on_video_track, rtc_connection))
+		output_track, sender = create_output_track(rtc_connection, buffer_size)
+		sender.configure_bitrate(bitrate_init, bitrate_min, bitrate_max)
+
+		rtc_connection.on('track', partial(on_video_track, rtc_connection, output_track))
 
 		await rtc_connection.setRemoteDescription(rtc_offer)
 		await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py
index 24c2825e..14b9fdb4 100644
--- a/facefusion/apis/stream_helper.py
+++ b/facefusion/apis/stream_helper.py
@@ -1,7 +1,8 @@
 import asyncio
-from typing import cast
+from typing import Tuple
 
-from aiortc import MediaStreamTrack, RTCPeerConnection, VideoStreamTrack
+from aiortc import MediaStreamTrack, QueuedVideoStreamTrack, RTCPeerConnection, RTCRtpSender
+from aiortc.mediastreams import MediaStreamError
 from av import VideoFrame
 
 from facefusion.streamer import process_vision_frame
@@ -10,27 +11,34 @@ from facefusion.streamer import process_vision_frame
 def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
 	target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
 	output_vision_frame = process_vision_frame(target_vision_frame)
-	return VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
+	output_stream_frame = VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
+	output_stream_frame.pts = target_stream_frame.pts
+	output_stream_frame.time_base = target_stream_frame.time_base
+	return output_stream_frame
 
 
-def create_output_track(target_track : MediaStreamTrack) -> VideoStreamTrack:
-	output_track = VideoStreamTrack()
-
-	async def read_stream_frame() -> VideoFrame:
-		target_stream_frame = cast(VideoFrame, await target_track.recv())
-		output_stream_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_stream_frame)
-		output_stream_frame.pts = target_stream_frame.pts
-		output_stream_frame.time_base = target_stream_frame.time_base
-		return output_stream_frame
-
-	output_track.recv = read_stream_frame
-	return output_track
+def create_output_track(rtc_connection : RTCPeerConnection, buffer_size : int) -> Tuple[QueuedVideoStreamTrack, RTCRtpSender]:
+	output_track = QueuedVideoStreamTrack(buffer_size = buffer_size)
+	sender = rtc_connection.addTrack(output_track)
+	return output_track, sender
 
 
-def on_video_track(rtc_connection : RTCPeerConnection, target_track : MediaStreamTrack) -> None:
+async def process_and_enqueue(target_track : MediaStreamTrack, output_track : QueuedVideoStreamTrack) -> None:
+	loop = asyncio.get_running_loop()
+
+	while True:
+		try:
+			target_stream_frame = await target_track.recv()
+		except MediaStreamError:
+			break # track ended, stop forwarding frames
+
+		output_stream_frame = await loop.run_in_executor(None, process_stream_frame, target_stream_frame) #type:ignore[arg-type]
+		await output_track.put(output_stream_frame)
+
+
+def on_video_track(rtc_connection : RTCPeerConnection, output_track : QueuedVideoStreamTrack, target_track : MediaStreamTrack) -> None:
 	if target_track.kind == 'audio':
 		rtc_connection.addTrack(target_track)
 	if target_track.kind == 'video':
-		output_track = create_output_track(target_track)
-		rtc_connection.addTrack(output_track)
+		asyncio.create_task(process_and_enqueue(target_track, output_track))
diff --git a/requirements.txt b/requirements.txt
index 238feb42..afc8942f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ nvidia-ml-py==13.590.48
 psutil==7.2.2
 tqdm==4.67.3
 scipy==1.16.3
-aiortc==1.14.0
+aiortc @ git+https://github.com/facefusion/aiortc.git@feat/dynamic-bitrate
 starlette==0.52.1
 uvicorn==0.41.0
 websockets==16.0