mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 19:01:35 +02:00
Refactor stream_helper: queue-based audio/video loops with unified threading (#1116)
* rearrange methods following the flow
* add test_stream_helper.py
* fix lint
* fix lint
* refactor audio flow to match video by replacing dequeue with queue
* remove unused keyframe interval
* remove try block
* remove while True
* simplify run_aom_encode_loop and run_vp8_encode_loop
* cleanup names
* simplify run_opus_encode_loop
* move opus_encoder creation to run_opus_encode_loop
* add todos
* fix lint
* update todos and tests
This commit is contained in:
+159
-146
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
import queue # TODO: try deque
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Optional, Tuple, cast, get_args
|
||||
|
||||
@@ -17,29 +17,97 @@ from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
|
||||
|
||||
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
|
||||
websocket_event = await websocket.receive()
|
||||
# TODO: refine this method
|
||||
async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
stream_codec : VideoCodec = 'av1'
|
||||
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or bytes()
|
||||
if websocket.query_params.get('codec') in get_args(VideoCodec):
|
||||
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
|
||||
|
||||
if len(frame_buffer) > 1:
|
||||
yield frame_buffer[0], frame_buffer[1:]
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
websocket_event = await websocket.receive()
|
||||
if session_id:
|
||||
stream_frames = receive_stream_frames(websocket)
|
||||
first_vision_frame : Optional[VisionFrame] = None
|
||||
|
||||
async for first_frame_type, first_frame_buffer in stream_frames:
|
||||
if first_frame_type == 1:
|
||||
first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
break
|
||||
|
||||
if numpy.any(first_vision_frame):
|
||||
resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0])
|
||||
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
|
||||
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
|
||||
audio_temp = numpy.array([], dtype = numpy.float32)
|
||||
|
||||
vision_frame_queue.put(first_vision_frame)
|
||||
rtc_store.create_rtc_peers(session_id)
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
|
||||
if stream_codec == 'av1':
|
||||
video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution)
|
||||
if stream_codec == 'vp8':
|
||||
video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution)
|
||||
|
||||
audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id)
|
||||
await websocket.send_text('ready')
|
||||
|
||||
async for frame_type, frame_buffer in stream_frames:
|
||||
if frame_type == 1:
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
if vision_frame_queue.qsize():
|
||||
vision_frame_queue.get_nowait()
|
||||
vision_frame_queue.put(vision_frame)
|
||||
|
||||
if frame_type == 2:
|
||||
audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ])
|
||||
|
||||
while len(audio_temp) >= 1920:
|
||||
audio_chunk_queue.put(audio_temp[:1920].tobytes())
|
||||
audio_temp = audio_temp[1920:]
|
||||
|
||||
vision_frame_queue.put(None)
|
||||
audio_chunk_queue.put(None)
|
||||
|
||||
await video_encode_task
|
||||
await audio_encode_task
|
||||
|
||||
rtc_store.destroy_rtc_peers(session_id)
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
|
||||
websocket_event = await websocket.receive()
|
||||
# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream
|
||||
async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or bytes()
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
yield vision_frame
|
||||
if source_paths:
|
||||
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
|
||||
|
||||
websocket_event = await websocket.receive()
|
||||
if numpy.any(capture_vision_frame):
|
||||
output_vision_frame = process_vision_frame(capture_vision_frame)
|
||||
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
|
||||
|
||||
if is_success:
|
||||
await websocket.send_bytes(output_frame_buffer.tobytes())
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop
|
||||
@@ -76,170 +144,115 @@ def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[Sdp
|
||||
return None
|
||||
|
||||
|
||||
def run_aom_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
|
||||
aom_encoder = create_aom_encoder(initial_resolution, 4500, 8, 10)
|
||||
current_resolution = initial_resolution
|
||||
pts = 0
|
||||
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
|
||||
def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10)
|
||||
temp_resolution = frame_resolution
|
||||
timestamp = 0
|
||||
|
||||
while vision_frame_deque:
|
||||
vision_frame = vision_frame_deque[-1]
|
||||
output_frame = process_vision_frame(vision_frame)
|
||||
frame_resolution = (output_frame.shape[1], output_frame.shape[0])
|
||||
vision_frame = vision_frame_queue.get()
|
||||
|
||||
if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]:
|
||||
if aom_encoder:
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
while numpy.any(vision_frame) and aom_encoder:
|
||||
output_vision_frame = process_vision_frame(vision_frame)
|
||||
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
|
||||
|
||||
current_resolution = frame_resolution
|
||||
aom_encoder = create_aom_encoder(current_resolution, 4500, 8, 10)
|
||||
pts = 0
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
|
||||
if aom_encoder:
|
||||
yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420)
|
||||
frame_buffer = encode_aom_buffer(aom_encoder, yuv_frame.tobytes(), frame_resolution, pts)
|
||||
if output_frame_buffer and rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
|
||||
|
||||
if frame_buffer:
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
continue
|
||||
|
||||
if rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, frame_buffer)
|
||||
|
||||
pts += 1
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
temp_resolution = output_resolution
|
||||
aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10)
|
||||
timestamp = 0
|
||||
|
||||
if aom_encoder:
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
|
||||
|
||||
def run_vp8_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
|
||||
vpx_encoder = create_vpx_encoder(initial_resolution, 4500, 8, 16)
|
||||
current_resolution = initial_resolution
|
||||
pts = 0
|
||||
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
|
||||
def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16)
|
||||
temp_resolution = frame_resolution
|
||||
timestamp = 0
|
||||
|
||||
while vision_frame_deque:
|
||||
vision_frame = vision_frame_deque[-1]
|
||||
output_frame = process_vision_frame(vision_frame)
|
||||
frame_resolution = (output_frame.shape[1], output_frame.shape[0])
|
||||
vision_frame = vision_frame_queue.get()
|
||||
|
||||
if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]:
|
||||
if vpx_encoder:
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
while numpy.any(vision_frame) and vpx_encoder:
|
||||
output_vision_frame = process_vision_frame(vision_frame)
|
||||
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
|
||||
|
||||
current_resolution = frame_resolution
|
||||
vpx_encoder = create_vpx_encoder(current_resolution, 4500, 8, 16)
|
||||
pts = 0
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
|
||||
if vpx_encoder:
|
||||
yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420)
|
||||
frame_buffer = encode_vpx_buffer(vpx_encoder, yuv_frame.tobytes(), frame_resolution, pts)
|
||||
if output_frame_buffer and rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
|
||||
|
||||
if frame_buffer:
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
continue
|
||||
|
||||
if rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, frame_buffer)
|
||||
|
||||
pts += 1
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
temp_resolution = output_resolution
|
||||
vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16)
|
||||
timestamp = 0
|
||||
|
||||
if vpx_encoder:
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
|
||||
|
||||
# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream
|
||||
async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards
|
||||
def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
|
||||
opus_encoder = create_opus_encoder(48000, 2)
|
||||
audio_timestamp = 0
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
audio_chunk = audio_chunk_queue.get()
|
||||
|
||||
if source_paths:
|
||||
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
|
||||
while audio_chunk: # TODO: improve this condition with b''
|
||||
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
|
||||
if numpy.any(capture_vision_frame):
|
||||
output_vision_frame = process_vision_frame(capture_vision_frame)
|
||||
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
|
||||
if audio_buffer and rtc_peers:
|
||||
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
|
||||
|
||||
if is_success:
|
||||
await websocket.send_bytes(output_frame_buffer.tobytes())
|
||||
audio_timestamp += 960
|
||||
audio_chunk = audio_chunk_queue.get()
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
if opus_encoder:
|
||||
destroy_opus_encoder(opus_encoder)
|
||||
|
||||
|
||||
async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
stream_codec : VideoCodec = 'av1'
|
||||
# TODO: needs refinement
|
||||
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
if websocket.query_params.get('codec') in get_args(VideoCodec):
|
||||
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or bytes()
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
if len(frame_buffer) > 1:
|
||||
yield frame_buffer[0], frame_buffer[1:]
|
||||
|
||||
if session_id:
|
||||
stream_frames = receive_stream_frames(websocket)
|
||||
first_vision_frame = None
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
# TODO: audio frames may arrive before video due to ScriptProcessor firing faster than canvas toBlob
|
||||
async for first_frame_type, first_frame_buffer in stream_frames:
|
||||
if first_frame_type == 1:
|
||||
first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
break
|
||||
|
||||
if numpy.any(first_vision_frame):
|
||||
resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0])
|
||||
keyframe_interval = int(state_manager.get_item('output_video_fps') or 30) # TODO: remove hardcoded via stream_video_fps
|
||||
vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1)
|
||||
opus_encoder = create_opus_encoder(48000, 2) # TODO: guard against opus_encoder being None
|
||||
audio_temp = numpy.array([], dtype = numpy.float32)
|
||||
audio_timestamp = 0
|
||||
# TODO: needs refinement, does it receive frames or a buffer?
|
||||
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
vision_frame_deque.append(first_vision_frame)
|
||||
rtc_store.create_rtc_peers(session_id)
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or bytes()
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
encode_loop = run_aom_encode_loop
|
||||
if numpy.any(vision_frame):
|
||||
yield vision_frame
|
||||
|
||||
if stream_codec == 'vp8':
|
||||
encode_loop = run_vp8_encode_loop
|
||||
|
||||
video_encode_task = event_loop.run_in_executor(None, encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval)
|
||||
await websocket.send_text('ready')
|
||||
|
||||
async for frame_type, frame_buffer in stream_frames:
|
||||
if frame_type == 1:
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
vision_frame_deque.append(vision_frame)
|
||||
|
||||
if frame_type == 2:
|
||||
audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ])
|
||||
|
||||
while len(audio_temp) >= 1920:
|
||||
audio_chunk = audio_temp[:1920]
|
||||
audio_temp = audio_temp[1920:]
|
||||
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk.tobytes(), 960)
|
||||
|
||||
if audio_buffer:
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
|
||||
if rtc_peers:
|
||||
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
|
||||
|
||||
audio_timestamp += 960
|
||||
|
||||
vision_frame_deque.clear()
|
||||
await video_encode_task
|
||||
|
||||
if opus_encoder:
|
||||
destroy_opus_encoder(opus_encoder)
|
||||
|
||||
rtc_store.destroy_rtc_peers(session_id)
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
import asyncio
|
||||
import queue
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import pytest
|
||||
from numpy.typing import NDArray
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
|
||||
from facefusion.hash_helper import create_hash
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
|
||||
def _make_handler_websocket(events : list[Any]) -> MagicMock:
    """
    Build a mocked websocket for the stream handlers.

    The mock exposes an empty ``scope``, reports ``WebSocketState.CONNECTED``
    as its ``client_state`` and provides awaitable ``accept``, ``send_text``
    and ``close`` methods. Each awaited ``receive`` call returns the next
    entry of ``events``, in order.
    """
    websocket = MagicMock()
    websocket.scope = {}
    websocket.client_state = WebSocketState.CONNECTED
    websocket.accept = AsyncMock()
    websocket.send_text = AsyncMock()
    websocket.close = AsyncMock()
    # side_effect with a list hands out one event per receive() call
    websocket.receive = AsyncMock(side_effect = events)
    return websocket
|
||||
|
||||
|
||||
def _make_video_packet(frame : NDArray[Any]) -> bytes:
    """
    Encode ``frame`` as JPEG and prefix it with the byte ``0x01``.

    The handler under test treats the first byte of a packet as the frame
    type and decodes type ``1`` payloads as images.
    """
    is_success, encoded = cv2.imencode('.jpg', frame)
    # fail loudly instead of building a packet from a failed encode
    assert is_success, 'cv2.imencode failed to encode the test frame'
    return b'\x01' + encoded.tobytes()
|
||||
|
||||
|
||||
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
    """
    Serialize ``samples`` to raw bytes and prefix them with the byte ``0x02``.

    The handler under test treats the first byte of a packet as the frame
    type and handles type ``2`` payloads as audio.
    """
    return b'\x02' + samples.tobytes()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_run_video_encode_loop(video_codec : str) -> None:
    """
    Exercise the av1 and vp8 encode loops with mocked encoders and peers.

    Each scenario fills a fresh queue, runs the loop synchronously and checks
    how often encoded video is forwarded to the rtc peers.
    """
    frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
    black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)
    prefix = 'facefusion.apis.stream_helper.'

    # default to the av1 helpers, switch to the vp8 ones when parametrized so
    create_name = prefix + 'create_aom_encoder'
    encode_name = prefix + 'encode_aom_buffer'
    destroy_name = prefix + 'destroy_aom_encoder'
    run_loop = run_aom_encode_loop

    if video_codec == 'vp8':
        create_name = prefix + 'create_vpx_encoder'
        encode_name = prefix + 'encode_vpx_buffer'
        destroy_name = prefix + 'destroy_vpx_encoder'
        run_loop = run_vp8_encode_loop

    # scenario: one frame followed by the None sentinel is sent exactly once
    vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
    vision_frame_queue.put(frame)
    vision_frame_queue.put(None)

    with patch(prefix + 'process_vision_frame', return_value = frame), \
         patch(create_name, return_value = MagicMock()), \
         patch(encode_name, return_value = b'encoded'), \
         patch(destroy_name), \
         patch(prefix + 'rtc_store') as mock_rtc_store, \
         patch(prefix + 'rtc') as mock_rtc:
        mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        mock_rtc.send_video_to_peers.assert_called_once()

    # scenario: three queued frames result in three sends
    vision_frame_queue = queue.Queue()
    vision_frame_queue.put(frame)
    vision_frame_queue.put(frame)
    vision_frame_queue.put(frame)
    vision_frame_queue.put(None)

    with patch(prefix + 'process_vision_frame', return_value = frame), \
         patch(create_name, return_value = MagicMock()), \
         patch(encode_name, return_value = b'encoded'), \
         patch(destroy_name), \
         patch(prefix + 'rtc_store') as mock_rtc_store, \
         patch(prefix + 'rtc') as mock_rtc:
        mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        assert mock_rtc.send_video_to_peers.call_count == 3

    # scenario: an all-zero frame with no sentinel sends nothing and the encoder is destroyed once
    vision_frame_queue = queue.Queue()
    vision_frame_queue.put(black_frame)

    with patch(create_name, return_value = MagicMock()), \
         patch(destroy_name) as mock_destroy, \
         patch(prefix + 'rtc') as mock_rtc:
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        mock_rtc.send_video_to_peers.assert_not_called()
        mock_destroy.assert_called_once()

    # scenario: an empty encoded buffer is never forwarded to the peers
    vision_frame_queue = queue.Queue()
    vision_frame_queue.put(frame)
    vision_frame_queue.put(None)

    with patch(prefix + 'process_vision_frame', return_value = frame), \
         patch(create_name, return_value = MagicMock()), \
         patch(encode_name, return_value = b''), \
         patch(destroy_name), \
         patch(prefix + 'rtc_store'), \
         patch(prefix + 'rtc') as mock_rtc:
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        mock_rtc.send_video_to_peers.assert_not_called()

    # scenario: the processed frame is 128x128 while the loop starts at 64x64,
    # so the encoder is expected to be created and destroyed twice with one send
    vision_frame_queue = queue.Queue()
    vision_frame_queue.put(small_frame)
    vision_frame_queue.put(None)

    with patch(prefix + 'process_vision_frame', return_value = large_frame), \
         patch(create_name, return_value = MagicMock()) as mock_create, \
         patch(encode_name, return_value = b'encoded'), \
         patch(destroy_name) as mock_destroy, \
         patch(prefix + 'rtc_store') as mock_rtc_store, \
         patch(prefix + 'rtc') as mock_rtc:
        mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        assert mock_create.call_count == 2
        assert mock_destroy.call_count == 2
        mock_rtc.send_video_to_peers.assert_called_once()

    # scenario: encoder creation returns None, nothing may be sent
    vision_frame_queue = queue.Queue()
    vision_frame_queue.put(frame)
    vision_frame_queue.put(None)

    with patch(create_name, return_value = None), \
         patch(prefix + 'rtc') as mock_rtc:
        run_loop(vision_frame_queue, 'session-1', (64, 64))
        mock_rtc.send_video_to_peers.assert_not_called()
|
||||
|
||||
|
||||
# TODO: refine test
|
||||
def test_run_opus_encode_loop() -> None:
    """
    Exercise the opus encode loop with a mocked encoder and mocked peers.

    Each scenario fills a fresh queue, runs the loop synchronously and checks
    the calls made to ``rtc.send_audio_to_peers`` or ``destroy_opus_encoder``.
    """
    audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()

    # scenario: one chunk followed by the None sentinel is sent once with timestamp 0
    audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
    audio_chunk_queue.put(audio_chunk)
    audio_chunk_queue.put(None)

    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
         patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
         patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
         patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
         patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
        mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
        run_opus_encode_loop(audio_chunk_queue, 'session-1')
        mock_rtc.send_audio_to_peers.assert_called_once()
        assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0

    # scenario: two chunks are sent with timestamps 0 and 960
    audio_chunk_queue = queue.Queue()
    audio_chunk_queue.put(audio_chunk)
    audio_chunk_queue.put(audio_chunk)
    audio_chunk_queue.put(None)

    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
         patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
         patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
         patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
         patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
        mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
        run_opus_encode_loop(audio_chunk_queue, 'session-1')
        assert mock_rtc.send_audio_to_peers.call_count == 2
        assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0
        assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960

    # scenario: an empty encoded buffer is never forwarded to the peers
    audio_chunk_queue = queue.Queue()
    audio_chunk_queue.put(audio_chunk)
    audio_chunk_queue.put(None)

    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
         patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \
         patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
         patch('facefusion.apis.stream_helper.rtc_store'), \
         patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
        run_opus_encode_loop(audio_chunk_queue, 'session-1')
        mock_rtc.send_audio_to_peers.assert_not_called()

    # scenario: the encoder is destroyed once even when encoding returns None
    audio_chunk_queue = queue.Queue()
    audio_chunk_queue.put(audio_chunk)
    audio_chunk_queue.put(None)

    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
         patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \
         patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
         patch('facefusion.apis.stream_helper.rtc_store'), \
         patch('facefusion.apis.stream_helper.rtc'):
        run_opus_encode_loop(audio_chunk_queue, 'session-1')
        mock_destroy.assert_called_once()

    # scenario: an empty chunk with no sentinel sends nothing and the encoder is destroyed once
    audio_chunk_queue = queue.Queue()
    audio_chunk_queue.put(b'')

    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
         patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
         patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
        run_opus_encode_loop(audio_chunk_queue, 'session-1')
        mock_rtc.send_audio_to_peers.assert_not_called()
        mock_destroy.assert_called_once()
|
||||
mock_destroy.assert_called_once()
|
||||
|
||||
|
||||
# TODO: refine test
|
||||
def test_handle_video_stream() -> None:
    """
    Drive ``handle_video_stream`` end to end against a mocked websocket.

    The session lookup, the encode loops and the rtc store are patched, so
    only the handler's own orchestration is checked.
    """
    frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    video_packet = _make_video_packet(frame)
    audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))

    # scenario: valid session, one video packet, then disconnect
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ])

    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
         patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
         patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
         patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
         patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
         patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \
         patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \
         patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
        asyncio.run(handle_video_stream(websocket))
        websocket.accept.assert_called_once_with(subprotocol = 'proto')
        websocket.send_text.assert_called_once_with('ready')
        websocket.close.assert_called_once()
        mock_rtc.create_rtc_peers.assert_called_once_with('session-1')
        mock_rtc.destroy_rtc_peers.assert_called_once_with('session-1')
        # the video loop receives (queue, session_id, resolution) positionally
        _, loop_session_id, loop_resolution = mock_loop.call_args[0]
        assert loop_session_id == 'session-1'
        assert loop_resolution == (64, 64)

    # scenario: no session is found, the socket is accepted but nothing is streamed
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ])

    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
         patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
         patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \
         patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
         patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
        asyncio.run(handle_video_stream(websocket))
        websocket.accept.assert_called_once()
        websocket.send_text.assert_not_called()
        mock_rtc.create_rtc_peers.assert_not_called()

    # scenario: a video packet followed by an audio packet queues one audio chunk
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ])

    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
         patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
         patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
         patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
         patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
         patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \
         patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \
         patch('facefusion.apis.stream_helper.rtc_store'):
        asyncio.run(handle_video_stream(websocket))
        # the audio loop is mocked, so the queued chunk is still available here
        audio_queue = mock_audio_loop.call_args[0][0]
        assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'
|
||||
Reference in New Issue
Block a user