mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 17:36:16 +02:00
experiment to run video via websocket
This commit is contained in:
@@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream
|
||||
from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_live
|
||||
from facefusion.apis.middlewares.session import create_session_guard
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ def create_api() -> Starlette:
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
|
||||
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
|
||||
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes)
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
from functools import partial
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion import session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import on_video_track
|
||||
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_stream_encoder, create_stream_encoder, encode_stream_frame, process_stream_frame, read_stream_output
|
||||
from facefusion.streamer import process_vision_frame
|
||||
|
||||
|
||||
@@ -44,25 +41,46 @@ async def websocket_stream(websocket : WebSocket) -> None:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def webrtc_stream(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
async def websocket_stream_live(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
|
||||
if session_id:
|
||||
body = await request.json()
|
||||
rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
|
||||
rtc_connection = RTCPeerConnection()
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
rtc_connection.on('track', partial(on_video_track, rtc_connection))
|
||||
if source_paths:
|
||||
encoder = None
|
||||
reader_thread = None
|
||||
output_chunks = []
|
||||
lock = threading.Lock()
|
||||
|
||||
await rtc_connection.setRemoteDescription(rtc_offer)
|
||||
await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
|
||||
try:
|
||||
while True:
|
||||
image_buffer = await websocket.receive_bytes()
|
||||
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'sdp': rtc_connection.localDescription.sdp,
|
||||
'type': rtc_connection.localDescription.type
|
||||
})
|
||||
if numpy.any(target_vision_frame):
|
||||
temp_vision_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_vision_frame)
|
||||
|
||||
return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
if not encoder:
|
||||
height, width = temp_vision_frame.shape[:2]
|
||||
encoder = create_stream_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
|
||||
reader_thread = threading.Thread(target = read_stream_output, args = (encoder, output_chunks, lock), daemon = True)
|
||||
reader_thread.start()
|
||||
|
||||
encoded_bytes = encode_stream_frame(encoder, temp_vision_frame, output_chunks, lock)
|
||||
|
||||
if encoded_bytes:
|
||||
await websocket.send_bytes(encoded_bytes)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if encoder:
|
||||
close_stream_encoder(encoder)
|
||||
return
|
||||
|
||||
await websocket.close()
|
||||
|
||||
@@ -1,36 +1,67 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, VideoStreamTrack
|
||||
from av import VideoFrame
|
||||
import cv2
|
||||
|
||||
from facefusion import ffmpeg_builder
|
||||
from facefusion.ffmpeg import open_ffmpeg
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
STREAM_FPS : int = 30
|
||||
STREAM_QUALITY : int = 80
|
||||
|
||||
|
||||
def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
|
||||
target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
|
||||
output_vision_frame = process_vision_frame(target_vision_frame)
|
||||
return VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
|
||||
def create_stream_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
|
||||
commands = ffmpeg_builder.chain(
|
||||
ffmpeg_builder.capture_video(),
|
||||
ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
|
||||
ffmpeg_builder.set_input_fps(stream_fps),
|
||||
ffmpeg_builder.set_input('-'),
|
||||
ffmpeg_builder.set_video_encoder('libx264'),
|
||||
ffmpeg_builder.set_video_quality('libx264', stream_quality),
|
||||
ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
|
||||
[ '-tune', 'zerolatency' ],
|
||||
[ '-maxrate', '4000k' ],
|
||||
[ '-bufsize', '8000k' ],
|
||||
[ '-g', str(stream_fps) ],
|
||||
[ '-f', 'mp4' ],
|
||||
[ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ],
|
||||
ffmpeg_builder.set_output('-')
|
||||
)
|
||||
return open_ffmpeg(commands)
|
||||
|
||||
|
||||
def create_output_track(target_track : MediaStreamTrack) -> VideoStreamTrack:
|
||||
output_track = VideoStreamTrack()
|
||||
def read_stream_output(process : subprocess.Popen[bytes], output_chunks : list, lock : threading.Lock) -> None:
|
||||
while True:
|
||||
chunk = process.stdout.read(4096)
|
||||
|
||||
async def read_stream_frame() -> VideoFrame:
|
||||
target_stream_frame = cast(VideoFrame, await target_track.recv())
|
||||
output_stream_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_stream_frame)
|
||||
output_stream_frame.pts = target_stream_frame.pts
|
||||
output_stream_frame.time_base = target_stream_frame.time_base
|
||||
return output_stream_frame
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
output_track.recv = read_stream_frame
|
||||
return output_track
|
||||
with lock:
|
||||
output_chunks.append(chunk)
|
||||
|
||||
|
||||
def on_video_track(rtc_connection : RTCPeerConnection, target_track : MediaStreamTrack) -> None:
|
||||
if target_track.kind == 'audio':
|
||||
rtc_connection.addTrack(target_track)
|
||||
def encode_stream_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame, output_chunks : list, lock : threading.Lock) -> Optional[bytes]:
|
||||
raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
|
||||
process.stdin.write(raw_bytes)
|
||||
process.stdin.flush()
|
||||
|
||||
if target_track.kind == 'video':
|
||||
output_track = create_output_track(target_track)
|
||||
rtc_connection.addTrack(output_track)
|
||||
with lock:
|
||||
if output_chunks:
|
||||
encoded_bytes = b''.join(output_chunks)
|
||||
output_chunks.clear()
|
||||
return encoded_bytes
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def close_stream_encoder(process : subprocess.Popen[bytes]) -> None:
|
||||
process.stdin.close()
|
||||
process.wait()
|
||||
|
||||
|
||||
def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame:
|
||||
return process_vision_frame(vision_frame)
|
||||
|
||||
@@ -259,12 +259,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet',
|
||||
WebcamMode = Literal['inline', 'udp', 'v4l2']
|
||||
StreamMode = Literal['udp', 'v4l2']
|
||||
|
||||
RtcOfferSet = TypedDict('RtcOfferSet',
|
||||
{
|
||||
'sdp': str,
|
||||
'type': str
|
||||
})
|
||||
|
||||
ModelOptions : TypeAlias = Dict[str, Any]
|
||||
ModelSet : TypeAlias = Dict[str, ModelOptions]
|
||||
ModelInitializer : TypeAlias = NDArray[Any]
|
||||
|
||||
@@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48
|
||||
psutil==7.2.2
|
||||
tqdm==4.67.3
|
||||
scipy==1.16.3
|
||||
aiortc==1.14.0
|
||||
starlette==0.52.1
|
||||
uvicorn==0.41.0
|
||||
websockets==16.0
|
||||
|
||||
+8
-18
@@ -1,21 +1,11 @@
|
||||
from aiortc import RTCPeerConnection, VideoStreamTrack
|
||||
|
||||
from facefusion.types import RtcOfferSet
|
||||
import cv2
|
||||
import numpy
|
||||
|
||||
|
||||
async def create_rtc_offer() -> RtcOfferSet:
|
||||
rtc_connection = RTCPeerConnection()
|
||||
rtc_connection.addTrack(VideoStreamTrack())
|
||||
rtc_offer = await rtc_connection.createOffer()
|
||||
def create_test_frame_bytes(width : int, height : int) -> bytes:
|
||||
vision_frame = numpy.zeros((height, width, 3), dtype = numpy.uint8)
|
||||
is_success, image_buffer = cv2.imencode('.jpg', vision_frame)
|
||||
|
||||
await rtc_connection.setLocalDescription(rtc_offer)
|
||||
|
||||
rtc_offer_set : RtcOfferSet =\
|
||||
{
|
||||
'sdp': rtc_connection.localDescription.sdp,
|
||||
'type': rtc_connection.localDescription.type
|
||||
}
|
||||
|
||||
await rtc_connection.close()
|
||||
|
||||
return rtc_offer_set
|
||||
if is_success:
|
||||
return image_buffer.tobytes()
|
||||
return b''
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
from typing import Iterator
|
||||
|
||||
@@ -13,7 +12,6 @@ from facefusion.apis.core import create_api
|
||||
from facefusion.core import common_pre_check, processors_pre_check
|
||||
from facefusion.download import conditional_download
|
||||
from .assert_helper import get_test_example_file, get_test_examples_directory
|
||||
from .stream_helper import create_rtc_offer
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
|
||||
@@ -101,7 +99,7 @@ def test_stream_image(test_client : TestClient) -> None:
|
||||
assert output_vision_frame.shape == (1024, 1024, 3)
|
||||
|
||||
|
||||
def test_stream_video(test_client : TestClient) -> None:
|
||||
def test_stream_live(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
@@ -129,12 +127,11 @@ def test_stream_video(test_client : TestClient) -> None:
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
rtc_offer = asyncio.run(create_rtc_offer())
|
||||
stream_response = test_client.post('/stream', json = rtc_offer, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
with test_client.websocket_connect('/stream/live', subprotocols =
|
||||
[
|
||||
'access_token.' + access_token
|
||||
]) as websocket:
|
||||
websocket.send_bytes(source_content)
|
||||
output_bytes = websocket.receive_bytes()
|
||||
|
||||
assert stream_response.status_code == 200
|
||||
assert stream_response.json().get('type') == 'answer'
|
||||
assert stream_response.json().get('sdp')
|
||||
assert len(output_bytes) > 0
|
||||
|
||||
Reference in New Issue
Block a user