experiment to run video via websocket

This commit is contained in:
henryruhs
2026-03-20 10:49:32 +01:00
parent 3acb71c44e
commit 87c2eebb2d
7 changed files with 113 additions and 84 deletions
+3 -3
View File
@@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream
from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_live
from facefusion.apis.middlewares.session import create_session_guard
@@ -29,10 +29,10 @@ def create_api() -> Starlette:
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]),
WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ])
]
api = Starlette(routes = routes)
+39 -21
View File
@@ -1,17 +1,14 @@
from functools import partial
import asyncio
import threading
import cv2
import numpy
from aiortc import RTCPeerConnection, RTCSessionDescription
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from starlette.websockets import WebSocket
from facefusion import session_context, session_manager, state_manager
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion.apis.stream_helper import on_video_track
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_stream_encoder, create_stream_encoder, encode_stream_frame, process_stream_frame, read_stream_output
from facefusion.streamer import process_vision_frame
@@ -44,25 +41,46 @@ async def websocket_stream(websocket : WebSocket) -> None:
await websocket.close()
async def webrtc_stream(request : Request) -> Response:
access_token = extract_access_token(request.scope)
async def websocket_stream_live(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
if session_id:
body = await request.json()
rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
rtc_connection = RTCPeerConnection()
await websocket.accept(subprotocol = subprotocol)
rtc_connection.on('track', partial(on_video_track, rtc_connection))
if source_paths:
encoder = None
reader_thread = None
output_chunks = []
lock = threading.Lock()
await rtc_connection.setRemoteDescription(rtc_offer)
await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
try:
while True:
image_buffer = await websocket.receive_bytes()
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
return JSONResponse(
{
'sdp': rtc_connection.localDescription.sdp,
'type': rtc_connection.localDescription.type
})
if numpy.any(target_vision_frame):
temp_vision_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_vision_frame)
return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR)
if not encoder:
height, width = temp_vision_frame.shape[:2]
encoder = create_stream_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
reader_thread = threading.Thread(target = read_stream_output, args = (encoder, output_chunks, lock), daemon = True)
reader_thread.start()
encoded_bytes = encode_stream_frame(encoder, temp_vision_frame, output_chunks, lock)
if encoded_bytes:
await websocket.send_bytes(encoded_bytes)
except Exception:
pass
if encoder:
close_stream_encoder(encoder)
return
await websocket.close()
+55 -24
View File
@@ -1,36 +1,67 @@
import asyncio
from typing import cast
import subprocess
import threading
from typing import Optional
from aiortc import MediaStreamTrack, RTCPeerConnection, VideoStreamTrack
from av import VideoFrame
import cv2
from facefusion import ffmpeg_builder
from facefusion.ffmpeg import open_ffmpeg
from facefusion.streamer import process_vision_frame
from facefusion.types import VisionFrame
STREAM_FPS : int = 30
STREAM_QUALITY : int = 80
def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
output_vision_frame = process_vision_frame(target_vision_frame)
return VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
def create_stream_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
commands = ffmpeg_builder.chain(
ffmpeg_builder.capture_video(),
ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
ffmpeg_builder.set_input_fps(stream_fps),
ffmpeg_builder.set_input('-'),
ffmpeg_builder.set_video_encoder('libx264'),
ffmpeg_builder.set_video_quality('libx264', stream_quality),
ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
[ '-tune', 'zerolatency' ],
[ '-maxrate', '4000k' ],
[ '-bufsize', '8000k' ],
[ '-g', str(stream_fps) ],
[ '-f', 'mp4' ],
[ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ],
ffmpeg_builder.set_output('-')
)
return open_ffmpeg(commands)
def create_output_track(target_track : MediaStreamTrack) -> VideoStreamTrack:
output_track = VideoStreamTrack()
def read_stream_output(process : subprocess.Popen[bytes], output_chunks : list, lock : threading.Lock) -> None:
while True:
chunk = process.stdout.read(4096)
async def read_stream_frame() -> VideoFrame:
target_stream_frame = cast(VideoFrame, await target_track.recv())
output_stream_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_stream_frame)
output_stream_frame.pts = target_stream_frame.pts
output_stream_frame.time_base = target_stream_frame.time_base
return output_stream_frame
if not chunk:
break
output_track.recv = read_stream_frame
return output_track
with lock:
output_chunks.append(chunk)
def on_video_track(rtc_connection : RTCPeerConnection, target_track : MediaStreamTrack) -> None:
if target_track.kind == 'audio':
rtc_connection.addTrack(target_track)
def encode_stream_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame, output_chunks : list, lock : threading.Lock) -> Optional[bytes]:
raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
process.stdin.write(raw_bytes)
process.stdin.flush()
if target_track.kind == 'video':
output_track = create_output_track(target_track)
rtc_connection.addTrack(output_track)
with lock:
if output_chunks:
encoded_bytes = b''.join(output_chunks)
output_chunks.clear()
return encoded_bytes
return None
def close_stream_encoder(process : subprocess.Popen[bytes]) -> None:
process.stdin.close()
process.wait()
def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame:
return process_vision_frame(vision_frame)
-6
View File
@@ -259,12 +259,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet',
WebcamMode = Literal['inline', 'udp', 'v4l2']
StreamMode = Literal['udp', 'v4l2']
RtcOfferSet = TypedDict('RtcOfferSet',
{
'sdp': str,
'type': str
})
ModelOptions : TypeAlias = Dict[str, Any]
ModelSet : TypeAlias = Dict[str, ModelOptions]
ModelInitializer : TypeAlias = NDArray[Any]
-1
View File
@@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48
psutil==7.2.2
tqdm==4.67.3
scipy==1.16.3
aiortc==1.14.0
starlette==0.52.1
uvicorn==0.41.0
websockets==16.0
+8 -18
View File
@@ -1,21 +1,11 @@
from aiortc import RTCPeerConnection, VideoStreamTrack
from facefusion.types import RtcOfferSet
import cv2
import numpy
async def create_rtc_offer() -> RtcOfferSet:
rtc_connection = RTCPeerConnection()
rtc_connection.addTrack(VideoStreamTrack())
rtc_offer = await rtc_connection.createOffer()
def create_test_frame_bytes(width : int, height : int) -> bytes:
vision_frame = numpy.zeros((height, width, 3), dtype = numpy.uint8)
is_success, image_buffer = cv2.imencode('.jpg', vision_frame)
await rtc_connection.setLocalDescription(rtc_offer)
rtc_offer_set : RtcOfferSet =\
{
'sdp': rtc_connection.localDescription.sdp,
'type': rtc_connection.localDescription.type
}
await rtc_connection.close()
return rtc_offer_set
if is_success:
return image_buffer.tobytes()
return b''
+8 -11
View File
@@ -1,4 +1,3 @@
import asyncio
import tempfile
from typing import Iterator
@@ -13,7 +12,6 @@ from facefusion.apis.core import create_api
from facefusion.core import common_pre_check, processors_pre_check
from facefusion.download import conditional_download
from .assert_helper import get_test_example_file, get_test_examples_directory
from .stream_helper import create_rtc_offer
@pytest.fixture(scope = 'module', autouse = True)
@@ -101,7 +99,7 @@ def test_stream_image(test_client : TestClient) -> None:
assert output_vision_frame.shape == (1024, 1024, 3)
def test_stream_video(test_client : TestClient) -> None:
def test_stream_live(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
@@ -129,12 +127,11 @@ def test_stream_video(test_client : TestClient) -> None:
'Authorization': 'Bearer ' + access_token
})
rtc_offer = asyncio.run(create_rtc_offer())
stream_response = test_client.post('/stream', json = rtc_offer, headers =
{
'Authorization': 'Bearer ' + access_token
})
with test_client.websocket_connect('/stream/live', subprotocols =
[
'access_token.' + access_token
]) as websocket:
websocket.send_bytes(source_content)
output_bytes = websocket.receive_bytes()
assert stream_response.status_code == 200
assert stream_response.json().get('type') == 'answer'
assert stream_response.json().get('sdp')
assert len(output_bytes) > 0