Replace aiortc with libdatachannel direct pipeline (#1083)

* fix stdin close error

* Refactor stream endpoint, fix encoder thread safety and improve tests

* fix and improve test

* remove not None

* use Enum

* use Enum and add todo

* remove poll
This commit is contained in:
Harisreedhar
2026-04-30 18:26:10 +05:30
committed by GitHub
parent cb086e9437
commit 6047463154
9 changed files with 208 additions and 170 deletions
+3 -3
View File
@@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream
from facefusion.apis.endpoints.stream import post_stream, websocket_stream
from facefusion.apis.middlewares.session import create_session_guard
@@ -29,10 +29,10 @@ def create_api() -> Starlette:
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
Route('/stream', post_stream, methods = [ 'POST' ], middleware = [ session_guard ]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ])
]
api = Starlette(routes = routes)
+15 -58
View File
@@ -1,76 +1,33 @@
from functools import partial
import cv2
import numpy
from aiortc import RTCPeerConnection, RTCSessionDescription
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from starlette.responses import Response
from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND
from starlette.websockets import WebSocket
from facefusion import session_context, session_manager, state_manager
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion import rtc_store, session_context, session_manager
from facefusion.apis.session_helper import extract_access_token
from facefusion.apis.stream_helper import create_output_track, on_video_track
from facefusion.streamer import process_vision_frame
from facefusion.apis.stream_helper import get_websocket_stream_mode, handle_image_stream, handle_video_stream
async def websocket_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
stream_mode = get_websocket_stream_mode(websocket.scope)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
if stream_mode == 'image':
await handle_image_stream(websocket)
await websocket.accept(subprotocol = subprotocol)
if source_paths:
try:
image_buffer = await websocket.receive_bytes()
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
if numpy.any(target_vision_frame):
temp_vision_frame = process_vision_frame(target_vision_frame)
is_success, output_vision_frame = cv2.imencode('.jpg', temp_vision_frame)
if is_success:
await websocket.send_bytes(output_vision_frame.tobytes())
except Exception:
pass
return
await websocket.close()
if stream_mode == 'video':
await handle_video_stream(websocket)
async def webrtc_stream(request : Request) -> Response:
async def post_stream(request : Request) -> Response:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
if session_id:
body = await request.json()
buffer_size = int(body.get('buffer_size', 30))
bitrate_init = int(body.get('bitrate_init', 100000))
bitrate_min = int(body.get('bitrate_min', 100000))
bitrate_max = int(body.get('bitrate_max', 4000000))
sdp_offer = (await request.body()).decode()
sdp_answer = rtc_store.add_rtc_viewer(session_id, sdp_offer)
rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
rtc_connection = RTCPeerConnection()
if sdp_answer:
return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp')
output_track, sender = create_output_track(rtc_connection, buffer_size)
sender.configure_bitrate(bitrate_init, bitrate_min, bitrate_max)
rtc_connection.on('track', partial(on_video_track, rtc_connection, output_track))
await rtc_connection.setRemoteDescription(rtc_offer)
await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
return JSONResponse(
{
'sdp': rtc_connection.localDescription.sdp,
'type': rtc_connection.localDescription.type
})
return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR)
return Response(status_code = HTTP_404_NOT_FOUND)
+104 -44
View File
@@ -2,53 +2,23 @@ import asyncio
import math
import os
import subprocess
from typing import Iterator, Optional, Tuple, cast
from collections import deque
from collections.abc import AsyncIterator
from typing import Optional, cast
from aiortc import MediaStreamTrack, QueuedVideoStreamTrack, RTCPeerConnection, RTCRtpSender
from aiortc.mediastreams import MediaStreamError
from av import VideoFrame
import cv2
import numpy
from starlette.datastructures import Headers
from starlette.types import Scope
from starlette.websockets import WebSocket, WebSocketState
from facefusion import rtc_store, session_context, session_manager, state_manager
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion.common_helper import is_linux, is_macos
from facefusion.ffmpeg import spawn_stream
from facefusion.streamer import process_vision_frame
from facefusion.types import Resolution, StreamBuffer, WebSocketStreamMode
def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
output_vision_frame = process_vision_frame(target_vision_frame)
output_stream_frame = VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
output_stream_frame.pts = target_stream_frame.pts
output_stream_frame.time_base = target_stream_frame.time_base
return output_stream_frame
def create_output_track(rtc_connection : RTCPeerConnection, buffer_size : int) -> Tuple[QueuedVideoStreamTrack, RTCRtpSender]:
output_track = QueuedVideoStreamTrack(buffer_size = buffer_size)
sender = rtc_connection.addTrack(output_track)
return output_track, sender
async def process_and_enqueue(target_track : MediaStreamTrack, output_track : QueuedVideoStreamTrack) -> None:
loop = asyncio.get_running_loop()
while True:
try:
target_stream_frame = await target_track.recv()
except MediaStreamError:
pass
output_stream_frame = await loop.run_in_executor(None, process_stream_frame, target_stream_frame) #type:ignore[arg-type]
await output_track.put(output_stream_frame)
def on_video_track(rtc_connection : RTCPeerConnection, output_track : QueuedVideoStreamTrack, target_track : MediaStreamTrack) -> None:
if target_track.kind == 'audio':
rtc_connection.addTrack(target_track)
if target_track.kind == 'video':
asyncio.create_task(process_and_enqueue(target_track, output_track))
from facefusion.types import Resolution, SessionId, VisionFrame, WebSocketStreamMode
def calculate_bitrate(resolution : Resolution) -> int: # TODO : improve the bitrate calculation
@@ -89,8 +59,21 @@ def read_pipe_buffer(pipe_handle : int, size : int) -> Optional[bytes]:
return None
def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBuffer]:
pipe_handle = process.stdout.fileno()
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
websocket_event = await websocket.receive()
while websocket_event.get('type') == 'websocket.receive':
frame_buffer = websocket_event.get('bytes') or b''
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
if numpy.any(vision_frame):
yield vision_frame
websocket_event = await websocket.receive()
def forward_rtc_frames(encoder : subprocess.Popen[bytes], session_id : SessionId) -> None:
pipe_handle = encoder.stdout.fileno()
if is_linux() or is_macos():
os.set_blocking(pipe_handle, True)
@@ -105,6 +88,83 @@ def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBu
frame_data = read_pipe_buffer(pipe_handle, frame_size)
if frame_data:
yield frame_data
rtc_store.send_rtc_frame(session_id, frame_data)
frame_header = read_pipe_buffer(pipe_handle, 12)
def submit_encoder_frame(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None:
output_vision_frame = process_vision_frame(vision_frame_deque[-1])
encoder.stdin.write(cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2RGB).tobytes())
encoder.stdin.flush()
def run_encode_loop(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None:
while vision_frame_deque:
submit_encoder_frame(encoder, vision_frame_deque)
encoder.stdin.close()
encoder.wait()
async def handle_image_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
await websocket.accept(subprotocol = subprotocol)
if source_paths:
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
if numpy.any(capture_vision_frame):
output_vision_frame = process_vision_frame(capture_vision_frame)
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
if is_success:
await websocket.send_bytes(output_frame_buffer.tobytes())
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
async def handle_video_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
await websocket.accept(subprotocol = subprotocol)
if session_id and source_paths:
output_video_fps = int(state_manager.get_item('output_video_fps') or 30) # TODO: resolve from target video fps
vision_frames = receive_vision_frames(websocket)
vision_frame = await anext(vision_frames, None)
if numpy.any(vision_frame):
resolution = (vision_frame.shape[1], vision_frame.shape[0])
encoder = spawn_stream(resolution, output_video_fps, calculate_bitrate(resolution), calculate_buffer_size(resolution))
vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1)
vision_frame_deque.append(vision_frame)
rtc_store.create_rtc_stream(session_id)
event_loop = asyncio.get_running_loop()
await event_loop.run_in_executor(None, submit_encoder_frame, encoder, vision_frame_deque)
await websocket.send_text('ready')
encode_task = event_loop.run_in_executor(None, run_encode_loop, encoder, vision_frame_deque)
rtc_task = event_loop.run_in_executor(None, forward_rtc_frames, encoder, session_id)
async for vision_frame in vision_frames:
vision_frame_deque.append(vision_frame)
vision_frame_deque.clear()
await asyncio.gather(encode_task, rtc_task)
rtc_store.destroy_rtc_stream(session_id)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
+13 -15
View File
@@ -1,37 +1,35 @@
from typing import List, Optional
from facefusion import rtc
from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore
from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore, SessionId
RTC_STREAMS : RtcStreamStore = {} # TODO: tie lifetime to session_id so streams are cleaned up on session expiry
RTC_STREAMS : RtcStreamStore = {}
def get_rtc_stream(stream_path : str) -> Optional[List[RtcPeer]]:
return RTC_STREAMS.get(stream_path)
def get_rtc_stream(session_id : SessionId) -> Optional[List[RtcPeer]]:
return RTC_STREAMS.get(session_id)
def create_rtc_stream(stream_path : str) -> None:
RTC_STREAMS[stream_path] = []
def create_rtc_stream(session_id : SessionId) -> None:
RTC_STREAMS[session_id] = []
def destroy_rtc_stream(stream_path : str) -> None:
peers = RTC_STREAMS.pop(stream_path, None)
def destroy_rtc_stream(session_id : SessionId) -> None:
peers = RTC_STREAMS.pop(session_id, None)
if peers:
rtc.delete_peers(peers)
def add_rtc_viewer(stream_path : str, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]:
peers = get_rtc_stream(stream_path)
if peers:
return rtc.handle_whep_offer(peers, sdp_offer)
def add_rtc_viewer(session_id : SessionId, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]:
if session_id in RTC_STREAMS:
return rtc.handle_whep_offer(RTC_STREAMS.get(session_id), sdp_offer)
return None
def send_rtc_frame(stream_path : str, frame_data : bytes) -> None:
peers = get_rtc_stream(stream_path)
def send_rtc_frame(session_id : SessionId, frame_data : bytes) -> None:
peers = get_rtc_stream(session_id)
if peers:
rtc.send_to_peers(peers, frame_data)
-1
View File
@@ -265,7 +265,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet',
WebcamMode = Literal['inline', 'udp', 'v4l2']
StreamMode = Literal['udp', 'v4l2']
WebSocketStreamMode = Literal['image', 'video']
StreamBuffer : TypeAlias = bytes
RtcOfferSet = TypedDict('RtcOfferSet',
{
-1
View File
@@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48
psutil==7.2.2
tqdm==4.67.3
scipy==1.16.3
aiortc @ git+https://github.com/facefusion/aiortc.git@feat/dynamic-bitrate
starlette==0.52.1
uvicorn==0.41.0
websockets==16.0
+51 -14
View File
@@ -1,21 +1,58 @@
from aiortc import RTCPeerConnection, VideoStreamTrack
import ctypes
import os
import threading
import time
from typing import Optional
from facefusion.types import RtcOfferSet
from starlette.testclient import TestClient
from facefusion import rtc
from facefusion.types import RtcSdpOffer
async def create_rtc_offer() -> RtcOfferSet:
rtc_connection = RTCPeerConnection()
rtc_connection.addTrack(VideoStreamTrack())
rtc_offer = await rtc_connection.createOffer()
def create_sdp_offer() -> Optional[RtcSdpOffer]:
rtc_library = rtc.create_static_rtc_library()
peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True)
await rtc_connection.setLocalDescription(rtc_offer)
media_video = os.linesep.join(
[
'm=video 9 UDP/TLS/RTP/SAVPF 96',
'a=rtpmap:96 VP8/90000',
'a=recvonly',
'a=mid:0',
''
]).encode()
media_audio = os.linesep.join(
[
'm=audio 9 UDP/TLS/RTP/SAVPF 111',
'a=rtpmap:111 opus/48000/2',
'a=recvonly',
'a=mid:1',
''
]).encode()
rtc_offer_set : RtcOfferSet =\
{
'sdp': rtc_connection.localDescription.sdp,
'type': rtc_connection.localDescription.type
}
rtc_library.rtcAddTrack(peer_connection, media_video)
rtc_library.rtcAddTrack(peer_connection, media_audio)
rtc_library.rtcSetLocalDescription(peer_connection, b'offer')
await rtc_connection.close()
buffer_size = 16384
buffer_string = ctypes.create_string_buffer(buffer_size)
wait_limit = time.monotonic() + 5
return rtc_offer_set
while time.monotonic() < wait_limit:
if rtc_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0:
sdp = buffer_string.value.decode()
rtc_library.rtcDeletePeerConnection(peer_connection)
return sdp
time.sleep(0.05)
rtc_library.rtcDeletePeerConnection(peer_connection)
return None
def open_websocket_stream(test_client : TestClient, subprotocols : list[str], source_content : bytes, ready_event : threading.Event, stop_event : threading.Event) -> None:
with test_client.websocket_connect('/stream', subprotocols = subprotocols) as websocket:
websocket.send_bytes(source_content)
websocket.receive_text()
ready_event.set()
stop_event.wait()
+19 -9
View File
@@ -1,5 +1,5 @@
import asyncio
import tempfile
import threading
from typing import Iterator
import cv2
@@ -13,7 +13,7 @@ from facefusion.apis.core import create_api
from facefusion.core import common_pre_check, processors_pre_check
from facefusion.download import conditional_download
from .assert_helper import get_test_example_file, get_test_examples_directory
from .stream_helper import create_rtc_offer
from .stream_helper import create_sdp_offer, open_websocket_stream
@pytest.fixture(scope = 'module', autouse = True)
@@ -92,7 +92,8 @@ def test_stream_image(test_client : TestClient) -> None:
with test_client.websocket_connect('/stream', subprotocols =
[
'access_token.' + access_token
'access_token.' + access_token,
'image'
]) as websocket:
websocket.send_bytes(source_content)
output_bytes = websocket.receive_bytes()
@@ -129,12 +130,21 @@ def test_stream_video(test_client : TestClient) -> None:
'Authorization': 'Bearer ' + access_token
})
rtc_offer = asyncio.run(create_rtc_offer())
stream_response = test_client.post('/stream', json = rtc_offer, headers =
ready_event = threading.Event()
stop_event = threading.Event()
stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token, 'video' ], source_content, ready_event, stop_event))
stream_thread.start()
ready_event.wait()
sdp_offer = create_sdp_offer()
stream_response = test_client.post('/stream', content = sdp_offer, headers =
{
'Authorization': 'Bearer ' + access_token
'Authorization': 'Bearer ' + access_token,
'Content-Type': 'application/sdp'
})
assert stream_response.status_code == 200
assert stream_response.json().get('type') == 'answer'
assert stream_response.json().get('sdp')
assert stream_response.status_code == 201
assert stream_response.text
stop_event.set()
stream_thread.join()
+3 -25
View File
@@ -1,9 +1,6 @@
import os
import subprocess
from facefusion import ffmpeg_builder
from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, forward_stream_frame, get_websocket_stream_mode, read_pipe_buffer
from facefusion.vision import pack_resolution
from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, get_websocket_stream_mode, read_pipe_buffer
def make_scope(protocol : str) -> dict[str, object]:
@@ -30,7 +27,7 @@ def test_calculate_buffer_size() -> None:
assert calculate_buffer_size((3840, 2160)) == 14000
def test_get_websocket_stream_mode() -> None:
def test_get_stream_mode() -> None:
assert get_websocket_stream_mode(make_scope('image')) == 'image'
assert get_websocket_stream_mode(make_scope('video')) == 'video'
@@ -47,23 +44,4 @@ def test_read_pipe_buffer() -> None:
os.close(read_fd)
def test_forward_frames() -> None:
resolution = (320, 240)
frame_size = resolution[0] * resolution[1] * 3
commands = ffmpeg_builder.run(ffmpeg_builder.chain(
ffmpeg_builder.capture_video(),
ffmpeg_builder.set_media_resolution(pack_resolution(resolution)),
ffmpeg_builder.set_input_fps(30),
ffmpeg_builder.set_input('-'),
ffmpeg_builder.set_video_encoder('libvpx'),
ffmpeg_builder.set_encoder_deadline('realtime'),
ffmpeg_builder.set_stream_quality(400),
ffmpeg_builder.set_muxer('ivf'),
ffmpeg_builder.set_output('-')
))
encoder = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
encoder.stdin.write(bytes(frame_size))
encoder.stdin.close()
for stream_buffer in forward_stream_frame(encoder):
assert 0 < len(stream_buffer) < frame_size
# TODO: add remaining tests