mirror of
https://github.com/facefusion/facefusion.git
synced 2026-05-01 05:47:51 +02:00
Replace aiortc with libdatachannel direct pipeline (#1083)
* fix stdin close error * Refactor stream endpoint, fix encoder thread safety and improve tests * fix and improve test * remove not None * use Enum * use Enum and add todo * remove poll
This commit is contained in:
@@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.endpoints.stream import webrtc_stream, websocket_stream
|
||||
from facefusion.apis.endpoints.stream import post_stream, websocket_stream
|
||||
from facefusion.apis.middlewares.session import create_session_guard
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ def create_api() -> Starlette:
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
|
||||
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/stream', webrtc_stream, methods = ['POST'], middleware = [session_guard]),
|
||||
Route('/stream', post_stream, methods = [ 'POST' ], middleware = [ session_guard ]),
|
||||
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [session_guard])
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes)
|
||||
|
||||
@@ -1,76 +1,33 @@
|
||||
from functools import partial
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion import session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion import rtc_store, session_context, session_manager
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import create_output_track, on_video_track
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.apis.stream_helper import get_websocket_stream_mode, handle_image_stream, handle_video_stream
|
||||
|
||||
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
stream_mode = get_websocket_stream_mode(websocket.scope)
|
||||
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
if stream_mode == 'image':
|
||||
await handle_image_stream(websocket)
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if source_paths:
|
||||
try:
|
||||
image_buffer = await websocket.receive_bytes()
|
||||
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(target_vision_frame):
|
||||
temp_vision_frame = process_vision_frame(target_vision_frame)
|
||||
is_success, output_vision_frame = cv2.imencode('.jpg', temp_vision_frame)
|
||||
|
||||
if is_success:
|
||||
await websocket.send_bytes(output_vision_frame.tobytes())
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
await websocket.close()
|
||||
if stream_mode == 'video':
|
||||
await handle_video_stream(websocket)
|
||||
|
||||
|
||||
async def webrtc_stream(request : Request) -> Response:
|
||||
async def post_stream(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
|
||||
if session_id:
|
||||
body = await request.json()
|
||||
buffer_size = int(body.get('buffer_size', 30))
|
||||
bitrate_init = int(body.get('bitrate_init', 100000))
|
||||
bitrate_min = int(body.get('bitrate_min', 100000))
|
||||
bitrate_max = int(body.get('bitrate_max', 4000000))
|
||||
sdp_offer = (await request.body()).decode()
|
||||
sdp_answer = rtc_store.add_rtc_viewer(session_id, sdp_offer)
|
||||
|
||||
rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
|
||||
rtc_connection = RTCPeerConnection()
|
||||
if sdp_answer:
|
||||
return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp')
|
||||
|
||||
output_track, sender = create_output_track(rtc_connection, buffer_size)
|
||||
sender.configure_bitrate(bitrate_init, bitrate_min, bitrate_max)
|
||||
|
||||
rtc_connection.on('track', partial(on_video_track, rtc_connection, output_track))
|
||||
|
||||
await rtc_connection.setRemoteDescription(rtc_offer)
|
||||
await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'sdp': rtc_connection.localDescription.sdp,
|
||||
'type': rtc_connection.localDescription.type
|
||||
})
|
||||
|
||||
return Response(status_code = HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
@@ -2,53 +2,23 @@ import asyncio
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Iterator, Optional, Tuple, cast
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Optional, cast
|
||||
|
||||
from aiortc import MediaStreamTrack, QueuedVideoStreamTrack, RTCPeerConnection, RTCRtpSender
|
||||
from aiortc.mediastreams import MediaStreamError
|
||||
from av import VideoFrame
|
||||
import cv2
|
||||
import numpy
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import Scope
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
|
||||
from facefusion import rtc_store, session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.common_helper import is_linux, is_macos
|
||||
from facefusion.ffmpeg import spawn_stream
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import Resolution, StreamBuffer, WebSocketStreamMode
|
||||
|
||||
|
||||
def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
|
||||
target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
|
||||
output_vision_frame = process_vision_frame(target_vision_frame)
|
||||
output_stream_frame = VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
|
||||
output_stream_frame.pts = target_stream_frame.pts
|
||||
output_stream_frame.time_base = target_stream_frame.time_base
|
||||
return output_stream_frame
|
||||
|
||||
|
||||
def create_output_track(rtc_connection : RTCPeerConnection, buffer_size : int) -> Tuple[QueuedVideoStreamTrack, RTCRtpSender]:
|
||||
output_track = QueuedVideoStreamTrack(buffer_size = buffer_size)
|
||||
sender = rtc_connection.addTrack(output_track)
|
||||
return output_track, sender
|
||||
|
||||
|
||||
async def process_and_enqueue(target_track : MediaStreamTrack, output_track : QueuedVideoStreamTrack) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
while True:
|
||||
try:
|
||||
target_stream_frame = await target_track.recv()
|
||||
except MediaStreamError:
|
||||
pass
|
||||
|
||||
output_stream_frame = await loop.run_in_executor(None, process_stream_frame, target_stream_frame) #type:ignore[arg-type]
|
||||
await output_track.put(output_stream_frame)
|
||||
|
||||
|
||||
def on_video_track(rtc_connection : RTCPeerConnection, output_track : QueuedVideoStreamTrack, target_track : MediaStreamTrack) -> None:
|
||||
if target_track.kind == 'audio':
|
||||
rtc_connection.addTrack(target_track)
|
||||
|
||||
if target_track.kind == 'video':
|
||||
asyncio.create_task(process_and_enqueue(target_track, output_track))
|
||||
from facefusion.types import Resolution, SessionId, VisionFrame, WebSocketStreamMode
|
||||
|
||||
|
||||
def calculate_bitrate(resolution : Resolution) -> int: # TODO : improve the bitrate calculation
|
||||
@@ -89,8 +59,21 @@ def read_pipe_buffer(pipe_handle : int, size : int) -> Optional[bytes]:
|
||||
return None
|
||||
|
||||
|
||||
def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBuffer]:
|
||||
pipe_handle = process.stdout.fileno()
|
||||
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or b''
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
yield vision_frame
|
||||
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
|
||||
def forward_rtc_frames(encoder : subprocess.Popen[bytes], session_id : SessionId) -> None:
|
||||
pipe_handle = encoder.stdout.fileno()
|
||||
|
||||
if is_linux() or is_macos():
|
||||
os.set_blocking(pipe_handle, True)
|
||||
@@ -105,6 +88,83 @@ def forward_stream_frame(process : subprocess.Popen[bytes]) -> Iterator[StreamBu
|
||||
frame_data = read_pipe_buffer(pipe_handle, frame_size)
|
||||
|
||||
if frame_data:
|
||||
yield frame_data
|
||||
rtc_store.send_rtc_frame(session_id, frame_data)
|
||||
|
||||
frame_header = read_pipe_buffer(pipe_handle, 12)
|
||||
|
||||
|
||||
def submit_encoder_frame(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None:
|
||||
output_vision_frame = process_vision_frame(vision_frame_deque[-1])
|
||||
encoder.stdin.write(cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2RGB).tobytes())
|
||||
encoder.stdin.flush()
|
||||
|
||||
|
||||
def run_encode_loop(encoder : subprocess.Popen[bytes], vision_frame_deque : deque[VisionFrame]) -> None:
|
||||
while vision_frame_deque:
|
||||
submit_encoder_frame(encoder, vision_frame_deque)
|
||||
|
||||
encoder.stdin.close()
|
||||
encoder.wait()
|
||||
|
||||
|
||||
async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if source_paths:
|
||||
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
|
||||
|
||||
if numpy.any(capture_vision_frame):
|
||||
output_vision_frame = process_vision_frame(capture_vision_frame)
|
||||
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
|
||||
|
||||
if is_success:
|
||||
await websocket.send_bytes(output_frame_buffer.tobytes())
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if session_id and source_paths:
|
||||
output_video_fps = int(state_manager.get_item('output_video_fps') or 30) # TODO: resolve from target video fps
|
||||
vision_frames = receive_vision_frames(websocket)
|
||||
vision_frame = await anext(vision_frames, None)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
encoder = spawn_stream(resolution, output_video_fps, calculate_bitrate(resolution), calculate_buffer_size(resolution))
|
||||
|
||||
vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1)
|
||||
|
||||
vision_frame_deque.append(vision_frame)
|
||||
rtc_store.create_rtc_stream(session_id)
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
await event_loop.run_in_executor(None, submit_encoder_frame, encoder, vision_frame_deque)
|
||||
await websocket.send_text('ready')
|
||||
encode_task = event_loop.run_in_executor(None, run_encode_loop, encoder, vision_frame_deque)
|
||||
rtc_task = event_loop.run_in_executor(None, forward_rtc_frames, encoder, session_id)
|
||||
|
||||
async for vision_frame in vision_frames:
|
||||
vision_frame_deque.append(vision_frame)
|
||||
|
||||
vision_frame_deque.clear()
|
||||
await asyncio.gather(encode_task, rtc_task)
|
||||
rtc_store.destroy_rtc_stream(session_id)
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
|
||||
+13
-15
@@ -1,37 +1,35 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from facefusion import rtc
|
||||
from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore
|
||||
from facefusion.types import RtcPeer, RtcSdpAnswer, RtcSdpOffer, RtcStreamStore, SessionId
|
||||
|
||||
RTC_STREAMS : RtcStreamStore = {} # TODO: tie lifetime to session_id so streams are cleaned up on session expiry
|
||||
RTC_STREAMS : RtcStreamStore = {}
|
||||
|
||||
|
||||
def get_rtc_stream(stream_path : str) -> Optional[List[RtcPeer]]:
|
||||
return RTC_STREAMS.get(stream_path)
|
||||
def get_rtc_stream(session_id : SessionId) -> Optional[List[RtcPeer]]:
|
||||
return RTC_STREAMS.get(session_id)
|
||||
|
||||
|
||||
def create_rtc_stream(stream_path : str) -> None:
|
||||
RTC_STREAMS[stream_path] = []
|
||||
def create_rtc_stream(session_id : SessionId) -> None:
|
||||
RTC_STREAMS[session_id] = []
|
||||
|
||||
|
||||
def destroy_rtc_stream(stream_path : str) -> None:
|
||||
peers = RTC_STREAMS.pop(stream_path, None)
|
||||
def destroy_rtc_stream(session_id : SessionId) -> None:
|
||||
peers = RTC_STREAMS.pop(session_id, None)
|
||||
|
||||
if peers:
|
||||
rtc.delete_peers(peers)
|
||||
|
||||
|
||||
def add_rtc_viewer(stream_path : str, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]:
|
||||
peers = get_rtc_stream(stream_path)
|
||||
|
||||
if peers:
|
||||
return rtc.handle_whep_offer(peers, sdp_offer)
|
||||
def add_rtc_viewer(session_id : SessionId, sdp_offer : RtcSdpOffer) -> Optional[RtcSdpAnswer]:
|
||||
if session_id in RTC_STREAMS:
|
||||
return rtc.handle_whep_offer(RTC_STREAMS.get(session_id), sdp_offer)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def send_rtc_frame(stream_path : str, frame_data : bytes) -> None:
|
||||
peers = get_rtc_stream(stream_path)
|
||||
def send_rtc_frame(session_id : SessionId, frame_data : bytes) -> None:
|
||||
peers = get_rtc_stream(session_id)
|
||||
|
||||
if peers:
|
||||
rtc.send_to_peers(peers, frame_data)
|
||||
|
||||
@@ -265,7 +265,6 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet',
|
||||
WebcamMode = Literal['inline', 'udp', 'v4l2']
|
||||
StreamMode = Literal['udp', 'v4l2']
|
||||
WebSocketStreamMode = Literal['image', 'video']
|
||||
StreamBuffer : TypeAlias = bytes
|
||||
|
||||
RtcOfferSet = TypedDict('RtcOfferSet',
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@ nvidia-ml-py==13.590.48
|
||||
psutil==7.2.2
|
||||
tqdm==4.67.3
|
||||
scipy==1.16.3
|
||||
aiortc @ git+https://github.com/facefusion/aiortc.git@feat/dynamic-bitrate
|
||||
starlette==0.52.1
|
||||
uvicorn==0.41.0
|
||||
websockets==16.0
|
||||
|
||||
+51
-14
@@ -1,21 +1,58 @@
|
||||
from aiortc import RTCPeerConnection, VideoStreamTrack
|
||||
import ctypes
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.types import RtcOfferSet
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import rtc
|
||||
from facefusion.types import RtcSdpOffer
|
||||
|
||||
|
||||
async def create_rtc_offer() -> RtcOfferSet:
|
||||
rtc_connection = RTCPeerConnection()
|
||||
rtc_connection.addTrack(VideoStreamTrack())
|
||||
rtc_offer = await rtc_connection.createOffer()
|
||||
def create_sdp_offer() -> Optional[RtcSdpOffer]:
|
||||
rtc_library = rtc.create_static_rtc_library()
|
||||
peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True)
|
||||
|
||||
await rtc_connection.setLocalDescription(rtc_offer)
|
||||
media_video = os.linesep.join(
|
||||
[
|
||||
'm=video 9 UDP/TLS/RTP/SAVPF 96',
|
||||
'a=rtpmap:96 VP8/90000',
|
||||
'a=recvonly',
|
||||
'a=mid:0',
|
||||
''
|
||||
]).encode()
|
||||
media_audio = os.linesep.join(
|
||||
[
|
||||
'm=audio 9 UDP/TLS/RTP/SAVPF 111',
|
||||
'a=rtpmap:111 opus/48000/2',
|
||||
'a=recvonly',
|
||||
'a=mid:1',
|
||||
''
|
||||
]).encode()
|
||||
|
||||
rtc_offer_set : RtcOfferSet =\
|
||||
{
|
||||
'sdp': rtc_connection.localDescription.sdp,
|
||||
'type': rtc_connection.localDescription.type
|
||||
}
|
||||
rtc_library.rtcAddTrack(peer_connection, media_video)
|
||||
rtc_library.rtcAddTrack(peer_connection, media_audio)
|
||||
rtc_library.rtcSetLocalDescription(peer_connection, b'offer')
|
||||
|
||||
await rtc_connection.close()
|
||||
buffer_size = 16384
|
||||
buffer_string = ctypes.create_string_buffer(buffer_size)
|
||||
wait_limit = time.monotonic() + 5
|
||||
|
||||
return rtc_offer_set
|
||||
while time.monotonic() < wait_limit:
|
||||
if rtc_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0:
|
||||
sdp = buffer_string.value.decode()
|
||||
rtc_library.rtcDeletePeerConnection(peer_connection)
|
||||
return sdp
|
||||
time.sleep(0.05)
|
||||
|
||||
rtc_library.rtcDeletePeerConnection(peer_connection)
|
||||
return None
|
||||
|
||||
|
||||
def open_websocket_stream(test_client : TestClient, subprotocols : list[str], source_content : bytes, ready_event : threading.Event, stop_event : threading.Event) -> None:
|
||||
with test_client.websocket_connect('/stream', subprotocols = subprotocols) as websocket:
|
||||
websocket.send_bytes(source_content)
|
||||
websocket.receive_text()
|
||||
ready_event.set()
|
||||
stop_event.wait()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
import threading
|
||||
from typing import Iterator
|
||||
|
||||
import cv2
|
||||
@@ -13,7 +13,7 @@ from facefusion.apis.core import create_api
|
||||
from facefusion.core import common_pre_check, processors_pre_check
|
||||
from facefusion.download import conditional_download
|
||||
from .assert_helper import get_test_example_file, get_test_examples_directory
|
||||
from .stream_helper import create_rtc_offer
|
||||
from .stream_helper import create_sdp_offer, open_websocket_stream
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
|
||||
@@ -92,7 +92,8 @@ def test_stream_image(test_client : TestClient) -> None:
|
||||
|
||||
with test_client.websocket_connect('/stream', subprotocols =
|
||||
[
|
||||
'access_token.' + access_token
|
||||
'access_token.' + access_token,
|
||||
'image'
|
||||
]) as websocket:
|
||||
websocket.send_bytes(source_content)
|
||||
output_bytes = websocket.receive_bytes()
|
||||
@@ -129,12 +130,21 @@ def test_stream_video(test_client : TestClient) -> None:
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
rtc_offer = asyncio.run(create_rtc_offer())
|
||||
stream_response = test_client.post('/stream', json = rtc_offer, headers =
|
||||
ready_event = threading.Event()
|
||||
stop_event = threading.Event()
|
||||
stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token, 'video' ], source_content, ready_event, stop_event))
|
||||
stream_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
sdp_offer = create_sdp_offer()
|
||||
stream_response = test_client.post('/stream', content = sdp_offer, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
'Authorization': 'Bearer ' + access_token,
|
||||
'Content-Type': 'application/sdp'
|
||||
})
|
||||
|
||||
assert stream_response.status_code == 200
|
||||
assert stream_response.json().get('type') == 'answer'
|
||||
assert stream_response.json().get('sdp')
|
||||
assert stream_response.status_code == 201
|
||||
assert stream_response.text
|
||||
|
||||
stop_event.set()
|
||||
stream_thread.join()
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from facefusion import ffmpeg_builder
|
||||
from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, forward_stream_frame, get_websocket_stream_mode, read_pipe_buffer
|
||||
from facefusion.vision import pack_resolution
|
||||
from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, get_websocket_stream_mode, read_pipe_buffer
|
||||
|
||||
|
||||
def make_scope(protocol : str) -> dict[str, object]:
|
||||
@@ -30,7 +27,7 @@ def test_calculate_buffer_size() -> None:
|
||||
assert calculate_buffer_size((3840, 2160)) == 14000
|
||||
|
||||
|
||||
def test_get_websocket_stream_mode() -> None:
|
||||
def test_get_stream_mode() -> None:
|
||||
assert get_websocket_stream_mode(make_scope('image')) == 'image'
|
||||
assert get_websocket_stream_mode(make_scope('video')) == 'video'
|
||||
|
||||
@@ -47,23 +44,4 @@ def test_read_pipe_buffer() -> None:
|
||||
os.close(read_fd)
|
||||
|
||||
|
||||
def test_forward_frames() -> None:
|
||||
resolution = (320, 240)
|
||||
frame_size = resolution[0] * resolution[1] * 3
|
||||
commands = ffmpeg_builder.run(ffmpeg_builder.chain(
|
||||
ffmpeg_builder.capture_video(),
|
||||
ffmpeg_builder.set_media_resolution(pack_resolution(resolution)),
|
||||
ffmpeg_builder.set_input_fps(30),
|
||||
ffmpeg_builder.set_input('-'),
|
||||
ffmpeg_builder.set_video_encoder('libvpx'),
|
||||
ffmpeg_builder.set_encoder_deadline('realtime'),
|
||||
ffmpeg_builder.set_stream_quality(400),
|
||||
ffmpeg_builder.set_muxer('ivf'),
|
||||
ffmpeg_builder.set_output('-')
|
||||
))
|
||||
encoder = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
|
||||
encoder.stdin.write(bytes(frame_size))
|
||||
encoder.stdin.close()
|
||||
|
||||
for stream_buffer in forward_stream_frame(encoder):
|
||||
assert 0 < len(stream_buffer) < frame_size
|
||||
# TODO: add remaining tests
|
||||
|
||||
Reference in New Issue
Block a user