diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index be9876bb..757b9737 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,5 +1,5 @@ import asyncio -from collections import deque +import queue # TODO: try deque from collections.abc import AsyncIterator from typing import Optional, Tuple, cast, get_args @@ -17,29 +17,97 @@ from facefusion.streamer import process_vision_frame from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame -async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]: - websocket_event = await websocket.receive() +# TODO: refine this method +async def handle_video_stream(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) + session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) + stream_codec : VideoCodec = 'av1' - while websocket_event.get('type') == 'websocket.receive': - frame_buffer = websocket_event.get('bytes') or bytes() + if websocket.query_params.get('codec') in get_args(VideoCodec): + stream_codec = cast(VideoCodec, websocket.query_params.get('codec')) - if len(frame_buffer) > 1: - yield frame_buffer[0], frame_buffer[1:] + await websocket.accept(subprotocol = subprotocol) - websocket_event = await websocket.receive() + if session_id: + stream_frames = receive_stream_frames(websocket) + first_vision_frame : Optional[VisionFrame] = None + + async for first_frame_type, first_frame_buffer in stream_frames: + if first_frame_type == 1: + first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) + break + + if numpy.any(first_vision_frame): + resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0]) + vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue() + audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue() + audio_temp = numpy.array([], dtype = numpy.float32) + + vision_frame_queue.put(first_vision_frame) + rtc_store.create_rtc_peers(session_id) + + event_loop = asyncio.get_running_loop() + + if stream_codec == 'av1': + video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution) + if stream_codec == 'vp8': + video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution) + + audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id) + await websocket.send_text('ready') + + async for frame_type, frame_buffer in stream_frames: + if frame_type == 1: + vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) + + if numpy.any(vision_frame): + if vision_frame_queue.qsize(): + vision_frame_queue.get_nowait() + vision_frame_queue.put(vision_frame) + + if frame_type == 2: + audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ]) + + while len(audio_temp) >= 1920: + audio_chunk_queue.put(audio_temp[:1920].tobytes()) + audio_temp = audio_temp[1920:] + + vision_frame_queue.put(None) + audio_chunk_queue.put(None) + + await video_encode_task + await audio_encode_task + + rtc_store.destroy_rtc_peers(session_id) + + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close() -async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]: - websocket_event = await websocket.receive() +# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream +async def handle_image_stream(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) + session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) + source_paths = state_manager.get_item('source_paths') - while websocket_event.get('type') == 'websocket.receive': - frame_buffer = websocket_event.get('bytes') or bytes() - vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) + await websocket.accept(subprotocol = subprotocol) - if numpy.any(vision_frame): - yield vision_frame + if source_paths: + capture_vision_frame = await anext(receive_vision_frames(websocket), None) - websocket_event = await websocket.receive() + if numpy.any(capture_vision_frame): + output_vision_frame = process_vision_frame(capture_vision_frame) + is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame) + + if is_success: + await websocket.send_bytes(output_frame_buffer.tobytes()) + + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close() # TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop @@ -76,170 +144,115 @@ def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[Sdp return None -def run_aom_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None: - aom_encoder = create_aom_encoder(initial_resolution, 4500, 8, 10) - current_resolution = initial_resolution - pts = 0 +# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards +def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: + aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10) + temp_resolution = frame_resolution + timestamp = 0 - while vision_frame_deque: - vision_frame = vision_frame_deque[-1] - output_frame = process_vision_frame(vision_frame) - frame_resolution = (output_frame.shape[1], output_frame.shape[0]) + vision_frame = vision_frame_queue.get() - if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]: - if aom_encoder: - destroy_aom_encoder(aom_encoder) + while numpy.any(vision_frame) and aom_encoder: + output_vision_frame = process_vision_frame(vision_frame) + output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) - current_resolution = frame_resolution - aom_encoder = create_aom_encoder(current_resolution, 4500, 8, 10) - pts = 0 + if output_resolution == temp_resolution: + output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp) + rtc_peers = rtc_store.get_rtc_peers(session_id) - if aom_encoder: - yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420) - frame_buffer = encode_aom_buffer(aom_encoder, yuv_frame.tobytes(), frame_resolution, pts) + if output_frame_buffer and rtc_peers: + rtc.send_video_to_peers(rtc_peers, output_frame_buffer) - if frame_buffer: - rtc_peers = rtc_store.get_rtc_peers(session_id) + timestamp += 1 + vision_frame = vision_frame_queue.get() + continue - if rtc_peers: - rtc.send_video_to_peers(rtc_peers, frame_buffer) - - pts += 1 + destroy_aom_encoder(aom_encoder) + temp_resolution = output_resolution + aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10) + timestamp = 0 if aom_encoder: destroy_aom_encoder(aom_encoder) -def run_vp8_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None: - vpx_encoder = create_vpx_encoder(initial_resolution, 4500, 8, 16) - current_resolution = initial_resolution - pts = 0 +# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards +def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: + vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16) + temp_resolution = frame_resolution + timestamp = 0 - while vision_frame_deque: - vision_frame = vision_frame_deque[-1] - output_frame = process_vision_frame(vision_frame) - frame_resolution = (output_frame.shape[1], output_frame.shape[0]) + vision_frame = vision_frame_queue.get() - if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]: - if vpx_encoder: - destroy_vpx_encoder(vpx_encoder) + while numpy.any(vision_frame) and vpx_encoder: + output_vision_frame = process_vision_frame(vision_frame) + output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) - current_resolution = frame_resolution - vpx_encoder = create_vpx_encoder(current_resolution, 4500, 8, 16) - pts = 0 + if output_resolution == temp_resolution: + output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp) + rtc_peers = rtc_store.get_rtc_peers(session_id) - if vpx_encoder: - yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420) - frame_buffer = encode_vpx_buffer(vpx_encoder, yuv_frame.tobytes(), frame_resolution, pts) + if output_frame_buffer and rtc_peers: + rtc.send_video_to_peers(rtc_peers, output_frame_buffer) - if frame_buffer: - rtc_peers = rtc_store.get_rtc_peers(session_id) + timestamp += 1 + vision_frame = vision_frame_queue.get() + continue - if rtc_peers: - rtc.send_video_to_peers(rtc_peers, frame_buffer) - - pts += 1 + destroy_vpx_encoder(vpx_encoder) + temp_resolution = output_resolution + vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16) + timestamp = 0 if vpx_encoder: destroy_vpx_encoder(vpx_encoder) -# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream -async def handle_image_stream(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - session_context.set_session_id(session_id) - source_paths = state_manager.get_item('source_paths') +# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards +def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None: + opus_encoder = create_opus_encoder(48000, 2) + audio_timestamp = 0 - await websocket.accept(subprotocol = subprotocol) + audio_chunk = audio_chunk_queue.get() - if source_paths: - capture_vision_frame = await anext(receive_vision_frames(websocket), None) + while audio_chunk: # TODO: improve this condition with b'' + audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960) + rtc_peers = rtc_store.get_rtc_peers(session_id) - if numpy.any(capture_vision_frame): - output_vision_frame = process_vision_frame(capture_vision_frame) - is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame) + if audio_buffer and rtc_peers: + rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp) - if is_success: - await websocket.send_bytes(output_frame_buffer.tobytes()) + audio_timestamp += 960 + audio_chunk = audio_chunk_queue.get() - if websocket.client_state == WebSocketState.CONNECTED: - await websocket.close() + if opus_encoder: + destroy_opus_encoder(opus_encoder) -async def handle_video_stream(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - session_context.set_session_id(session_id) - stream_codec : VideoCodec = 'av1' +# TODO: needs refinement +async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]: + websocket_event = await websocket.receive() - if websocket.query_params.get('codec') in get_args(VideoCodec): - stream_codec = cast(VideoCodec, websocket.query_params.get('codec')) + while websocket_event.get('type') == 'websocket.receive': + frame_buffer = websocket_event.get('bytes') or bytes() - await websocket.accept(subprotocol = subprotocol) + if len(frame_buffer) > 1: + yield frame_buffer[0], frame_buffer[1:] - if session_id: - stream_frames = receive_stream_frames(websocket) - first_vision_frame = None + websocket_event = await websocket.receive() - # TODO: audio frames may arrive before video due to ScriptProcessor firing faster than canvas toBlob - async for first_frame_type, first_frame_buffer in stream_frames: - if first_frame_type == 1: - first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) - break - if numpy.any(first_vision_frame): - resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0]) - keyframe_interval = int(state_manager.get_item('output_video_fps') or 30) # TODO: remove hardcoded via stream_video_fps - vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1) - opus_encoder = create_opus_encoder(48000, 2) # TODO: guard against opus_encoder being None - audio_temp = numpy.array([], dtype = numpy.float32) - audio_timestamp = 0 +# TODO: needs refinement, does it receive frames or a buffer? +async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]: + websocket_event = await websocket.receive() - vision_frame_deque.append(first_vision_frame) - rtc_store.create_rtc_peers(session_id) + while websocket_event.get('type') == 'websocket.receive': + frame_buffer = websocket_event.get('bytes') or bytes() + vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) - event_loop = asyncio.get_running_loop() - encode_loop = run_aom_encode_loop + if numpy.any(vision_frame): + yield vision_frame - if stream_codec == 'vp8': - encode_loop = run_vp8_encode_loop - - video_encode_task = event_loop.run_in_executor(None, encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval) - await websocket.send_text('ready') - - async for frame_type, frame_buffer in stream_frames: - if frame_type == 1: - vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(vision_frame): - vision_frame_deque.append(vision_frame) - - if frame_type == 2: - audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ]) - - while len(audio_temp) >= 1920: - audio_chunk = audio_temp[:1920] - audio_temp = audio_temp[1920:] - audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk.tobytes(), 960) - - if audio_buffer: - rtc_peers = rtc_store.get_rtc_peers(session_id) - - if rtc_peers: - rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp) - - audio_timestamp += 960 - - vision_frame_deque.clear() - await video_encode_task - - if opus_encoder: - destroy_opus_encoder(opus_encoder) - - rtc_store.destroy_rtc_peers(session_id) - - if websocket.client_state == WebSocketState.CONNECTED: - await websocket.close() + websocket_event = await websocket.receive() diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py new file mode 100644 index 00000000..48832759 --- /dev/null +++ b/tests/test_stream_helper.py @@ -0,0 +1,240 @@ +import asyncio +import queue +from typing import Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import cv2 +import numpy +import pytest +from numpy.typing import NDArray +from starlette.websockets import WebSocketState + +from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop +from facefusion.hash_helper import create_hash +from facefusion.types import VisionFrame + + +def _make_handler_websocket(events : list[Any]) -> MagicMock: + mock = MagicMock() + mock.scope = {} + mock.client_state = WebSocketState.CONNECTED + mock.accept = AsyncMock() + mock.send_text = AsyncMock() + mock.close = AsyncMock() + mock.receive = AsyncMock(side_effect = events) + return mock + + +def _make_video_packet(frame : NDArray[Any]) -> bytes: + _, encoded = cv2.imencode('.jpg', frame) + return b'\x01' + encoded.tobytes() + + +def _make_audio_packet(samples : NDArray[Any]) -> bytes: + return b'\x02' + samples.tobytes() + + +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) +def test_run_video_encode_loop(video_codec : str) -> None: + frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) + small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) + large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8) + black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8) + prefix = 'facefusion.apis.stream_helper.' + + create_name = prefix + 'create_aom_encoder' + encode_name = prefix + 'encode_aom_buffer' + destroy_name = prefix + 'destroy_aom_encoder' + run_loop = run_aom_encode_loop + + if video_codec == 'vp8': + create_name = prefix + 'create_vpx_encoder' + encode_name = prefix + 'encode_vpx_buffer' + destroy_name = prefix + 'destroy_vpx_encoder' + run_loop = run_vp8_encode_loop + + vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue() + vision_frame_queue.put(frame) + vision_frame_queue.put(None) + with patch(prefix + 'process_vision_frame', return_value = frame), \ + patch(create_name, return_value = MagicMock()), \ + patch(encode_name, return_value = b'encoded'), \ + patch(destroy_name), \ + patch(prefix + 'rtc_store') as mock_rtc_store, \ + patch(prefix + 'rtc') as mock_rtc: + mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] + run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc.send_video_to_peers.assert_called_once() + + vision_frame_queue = queue.Queue() + vision_frame_queue.put(frame) + vision_frame_queue.put(frame) + vision_frame_queue.put(frame) + vision_frame_queue.put(None) + with patch(prefix + 'process_vision_frame', return_value = frame), \ + patch(create_name, return_value = MagicMock()), \ + patch(encode_name, return_value = b'encoded'), \ + patch(destroy_name), \ + patch(prefix + 'rtc_store') as mock_rtc_store, \ + patch(prefix + 'rtc') as mock_rtc: + mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] + run_loop(vision_frame_queue, 'session-1', (64, 64)) + assert mock_rtc.send_video_to_peers.call_count == 3 + + vision_frame_queue = queue.Queue() + vision_frame_queue.put(black_frame) + with patch(create_name, return_value = MagicMock()), \ + patch(destroy_name) as mock_destroy, \ + patch(prefix + 'rtc') as mock_rtc: + run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc.send_video_to_peers.assert_not_called() + mock_destroy.assert_called_once() + + vision_frame_queue = queue.Queue() + vision_frame_queue.put(frame) + vision_frame_queue.put(None) + with patch(prefix + 'process_vision_frame', return_value = frame), \ + patch(create_name, return_value = MagicMock()), \ + patch(encode_name, return_value = b''), \ + patch(destroy_name), \ + patch(prefix + 'rtc_store'), \ + patch(prefix + 'rtc') as mock_rtc: + run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc.send_video_to_peers.assert_not_called() + + vision_frame_queue = queue.Queue() + vision_frame_queue.put(small_frame) + vision_frame_queue.put(None) + with patch(prefix + 'process_vision_frame', return_value = large_frame), \ + patch(create_name, return_value = MagicMock()) as mock_create, \ + patch(encode_name, return_value = b'encoded'), \ + patch(destroy_name) as mock_destroy, \ + patch(prefix + 'rtc_store') as mock_rtc_store, \ + patch(prefix + 'rtc') as mock_rtc: + mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] + run_loop(vision_frame_queue, 'session-1', (64, 64)) + assert mock_create.call_count == 2 + assert mock_destroy.call_count == 2 + mock_rtc.send_video_to_peers.assert_called_once() + + vision_frame_queue = queue.Queue() + vision_frame_queue.put(frame) + vision_frame_queue.put(None) + with patch(create_name, return_value = None), \ + patch(prefix + 'rtc') as mock_rtc: + run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc.send_video_to_peers.assert_not_called() + + +# TODO: refine test +def test_run_opus_encode_loop() -> None: + audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes() + + audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue() + audio_chunk_queue.put(audio_chunk) + audio_chunk_queue.put(None) + with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ + patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \ + patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ + patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ + patch('facefusion.apis.stream_helper.rtc') as mock_rtc: + mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] + run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_rtc.send_audio_to_peers.assert_called_once() + assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0 + + audio_chunk_queue = queue.Queue() + audio_chunk_queue.put(audio_chunk) + audio_chunk_queue.put(audio_chunk) + audio_chunk_queue.put(None) + with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ + patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \ + patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ + patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ + patch('facefusion.apis.stream_helper.rtc') as mock_rtc: + mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] + run_opus_encode_loop(audio_chunk_queue, 'session-1') + assert mock_rtc.send_audio_to_peers.call_count == 2 + assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0 + assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960 + + audio_chunk_queue = queue.Queue() + audio_chunk_queue.put(audio_chunk) + audio_chunk_queue.put(None) + with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ + patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \ + patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ + patch('facefusion.apis.stream_helper.rtc_store'), \ + patch('facefusion.apis.stream_helper.rtc') as mock_rtc: + run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_rtc.send_audio_to_peers.assert_not_called() + + audio_chunk_queue = queue.Queue() + audio_chunk_queue.put(audio_chunk) + audio_chunk_queue.put(None) + with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ + patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \ + patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ + patch('facefusion.apis.stream_helper.rtc_store'), \ + patch('facefusion.apis.stream_helper.rtc'): + run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_destroy.assert_called_once() + + audio_chunk_queue = queue.Queue() + audio_chunk_queue.put(b'') + with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ + patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ + patch('facefusion.apis.stream_helper.rtc') as mock_rtc: + run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_rtc.send_audio_to_peers.assert_not_called() + mock_destroy.assert_called_once() + + +# TODO: refine test +def test_handle_video_stream() -> None: + frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) + video_packet = _make_video_packet(frame) + audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32)) + + websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ]) + with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ + patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ + patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ + patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ + patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ + patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \ + patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \ + patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc: + asyncio.run(handle_video_stream(websocket)) + websocket.accept.assert_called_once_with(subprotocol = 'proto') + websocket.send_text.assert_called_once_with('ready') + websocket.close.assert_called_once() + mock_rtc.create_rtc_peers.assert_called_once_with('session-1') + mock_rtc.destroy_rtc_peers.assert_called_once_with('session-1') + _, loop_session_id, loop_resolution = mock_loop.call_args[0] + assert loop_session_id == 'session-1' + assert loop_resolution == (64, 64) + + websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ]) + with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ + patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ + patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \ + patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ + patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc: + asyncio.run(handle_video_stream(websocket)) + websocket.accept.assert_called_once() + websocket.send_text.assert_not_called() + mock_rtc.create_rtc_peers.assert_not_called() + + websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ]) + with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ + patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ + patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ + patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ + patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ + patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \ + patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \ + patch('facefusion.apis.stream_helper.rtc_store'): + asyncio.run(handle_video_stream(websocket)) + audio_queue = mock_audio_loop.call_args[0][0] + assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'