Refactor stream_helper: queue-based audio/video loops with unified threading (#1116)

* rearrange methods following the flow

* add test_stream_helper.py

* fix lint

* fix lint

* refactor audio flow to match video by replacing dequeue with queue

* remove unused keyframe interval

* remove try block

* remove while True

* simplify run_aom_encode_loop and run_vp8_encode_loop

* cleanup names

* simplify run_opus_encode_loop

* move opus_encoder creation to run_opus_encode_loop

* add todos

* fix lint

* update todos and tests
This commit is contained in:
Harisreedhar
2026-05-15 21:47:30 +05:30
committed by GitHub
parent 532464032b
commit 0019d3ad0f
2 changed files with 399 additions and 146 deletions
+159 -146
View File
@@ -1,5 +1,5 @@
import asyncio
from collections import deque
import queue # TODO: try deque
from collections.abc import AsyncIterator
from typing import Optional, Tuple, cast, get_args
@@ -17,29 +17,97 @@ from facefusion.streamer import process_vision_frame
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
websocket_event = await websocket.receive()
# TODO: refine this method
async def handle_video_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
stream_codec : VideoCodec = 'av1'
while websocket_event.get('type') == 'websocket.receive':
frame_buffer = websocket_event.get('bytes') or bytes()
if websocket.query_params.get('codec') in get_args(VideoCodec):
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
if len(frame_buffer) > 1:
yield frame_buffer[0], frame_buffer[1:]
await websocket.accept(subprotocol = subprotocol)
websocket_event = await websocket.receive()
if session_id:
stream_frames = receive_stream_frames(websocket)
first_vision_frame : Optional[VisionFrame] = None
async for first_frame_type, first_frame_buffer in stream_frames:
if first_frame_type == 1:
first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
break
if numpy.any(first_vision_frame):
resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0])
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
audio_temp = numpy.array([], dtype = numpy.float32)
vision_frame_queue.put(first_vision_frame)
rtc_store.create_rtc_peers(session_id)
event_loop = asyncio.get_running_loop()
if stream_codec == 'av1':
video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution)
if stream_codec == 'vp8':
video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution)
audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id)
await websocket.send_text('ready')
async for frame_type, frame_buffer in stream_frames:
if frame_type == 1:
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
if numpy.any(vision_frame):
if vision_frame_queue.qsize():
vision_frame_queue.get_nowait()
vision_frame_queue.put(vision_frame)
if frame_type == 2:
audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ])
while len(audio_temp) >= 1920:
audio_chunk_queue.put(audio_temp[:1920].tobytes())
audio_temp = audio_temp[1920:]
vision_frame_queue.put(None)
audio_chunk_queue.put(None)
await video_encode_task
await audio_encode_task
rtc_store.destroy_rtc_peers(session_id)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
websocket_event = await websocket.receive()
# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream
async def handle_image_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
while websocket_event.get('type') == 'websocket.receive':
frame_buffer = websocket_event.get('bytes') or bytes()
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
await websocket.accept(subprotocol = subprotocol)
if numpy.any(vision_frame):
yield vision_frame
if source_paths:
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
websocket_event = await websocket.receive()
if numpy.any(capture_vision_frame):
output_vision_frame = process_vision_frame(capture_vision_frame)
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
if is_success:
await websocket.send_bytes(output_frame_buffer.tobytes())
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop
@@ -76,170 +144,115 @@ def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[Sdp
return None
def run_aom_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
aom_encoder = create_aom_encoder(initial_resolution, 4500, 8, 10)
current_resolution = initial_resolution
pts = 0
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10)
temp_resolution = frame_resolution
timestamp = 0
while vision_frame_deque:
vision_frame = vision_frame_deque[-1]
output_frame = process_vision_frame(vision_frame)
frame_resolution = (output_frame.shape[1], output_frame.shape[0])
vision_frame = vision_frame_queue.get()
if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]:
if aom_encoder:
destroy_aom_encoder(aom_encoder)
while numpy.any(vision_frame) and aom_encoder:
output_vision_frame = process_vision_frame(vision_frame)
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
current_resolution = frame_resolution
aom_encoder = create_aom_encoder(current_resolution, 4500, 8, 10)
pts = 0
if output_resolution == temp_resolution:
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp)
rtc_peers = rtc_store.get_rtc_peers(session_id)
if aom_encoder:
yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420)
frame_buffer = encode_aom_buffer(aom_encoder, yuv_frame.tobytes(), frame_resolution, pts)
if output_frame_buffer and rtc_peers:
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
if frame_buffer:
rtc_peers = rtc_store.get_rtc_peers(session_id)
timestamp += 1
vision_frame = vision_frame_queue.get()
continue
if rtc_peers:
rtc.send_video_to_peers(rtc_peers, frame_buffer)
pts += 1
destroy_aom_encoder(aom_encoder)
temp_resolution = output_resolution
aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10)
timestamp = 0
if aom_encoder:
destroy_aom_encoder(aom_encoder)
def run_vp8_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
vpx_encoder = create_vpx_encoder(initial_resolution, 4500, 8, 16)
current_resolution = initial_resolution
pts = 0
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16)
temp_resolution = frame_resolution
timestamp = 0
while vision_frame_deque:
vision_frame = vision_frame_deque[-1]
output_frame = process_vision_frame(vision_frame)
frame_resolution = (output_frame.shape[1], output_frame.shape[0])
vision_frame = vision_frame_queue.get()
if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]:
if vpx_encoder:
destroy_vpx_encoder(vpx_encoder)
while numpy.any(vision_frame) and vpx_encoder:
output_vision_frame = process_vision_frame(vision_frame)
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
current_resolution = frame_resolution
vpx_encoder = create_vpx_encoder(current_resolution, 4500, 8, 16)
pts = 0
if output_resolution == temp_resolution:
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp)
rtc_peers = rtc_store.get_rtc_peers(session_id)
if vpx_encoder:
yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420)
frame_buffer = encode_vpx_buffer(vpx_encoder, yuv_frame.tobytes(), frame_resolution, pts)
if output_frame_buffer and rtc_peers:
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
if frame_buffer:
rtc_peers = rtc_store.get_rtc_peers(session_id)
timestamp += 1
vision_frame = vision_frame_queue.get()
continue
if rtc_peers:
rtc.send_video_to_peers(rtc_peers, frame_buffer)
pts += 1
destroy_vpx_encoder(vpx_encoder)
temp_resolution = output_resolution
vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16)
timestamp = 0
if vpx_encoder:
destroy_vpx_encoder(vpx_encoder)
# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream
async def handle_image_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards
def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
opus_encoder = create_opus_encoder(48000, 2)
audio_timestamp = 0
await websocket.accept(subprotocol = subprotocol)
audio_chunk = audio_chunk_queue.get()
if source_paths:
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
while audio_chunk: # TODO: improve this condition with b''
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960)
rtc_peers = rtc_store.get_rtc_peers(session_id)
if numpy.any(capture_vision_frame):
output_vision_frame = process_vision_frame(capture_vision_frame)
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
if audio_buffer and rtc_peers:
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
if is_success:
await websocket.send_bytes(output_frame_buffer.tobytes())
audio_timestamp += 960
audio_chunk = audio_chunk_queue.get()
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
if opus_encoder:
destroy_opus_encoder(opus_encoder)
async def handle_video_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
stream_codec : VideoCodec = 'av1'
# TODO: needs refinement
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
websocket_event = await websocket.receive()
if websocket.query_params.get('codec') in get_args(VideoCodec):
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
while websocket_event.get('type') == 'websocket.receive':
frame_buffer = websocket_event.get('bytes') or bytes()
await websocket.accept(subprotocol = subprotocol)
if len(frame_buffer) > 1:
yield frame_buffer[0], frame_buffer[1:]
if session_id:
stream_frames = receive_stream_frames(websocket)
first_vision_frame = None
websocket_event = await websocket.receive()
# TODO: audio frames may arrive before video due to ScriptProcessor firing faster than canvas toBlob
async for first_frame_type, first_frame_buffer in stream_frames:
if first_frame_type == 1:
first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
break
if numpy.any(first_vision_frame):
resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0])
keyframe_interval = int(state_manager.get_item('output_video_fps') or 30) # TODO: remove hardcoded via stream_video_fps
vision_frame_deque : deque[VisionFrame] = deque(maxlen = 1)
opus_encoder = create_opus_encoder(48000, 2) # TODO: guard against opus_encoder being None
audio_temp = numpy.array([], dtype = numpy.float32)
audio_timestamp = 0
# TODO: needs refinement, does it receive frames or a buffer?
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
websocket_event = await websocket.receive()
vision_frame_deque.append(first_vision_frame)
rtc_store.create_rtc_peers(session_id)
while websocket_event.get('type') == 'websocket.receive':
frame_buffer = websocket_event.get('bytes') or bytes()
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
event_loop = asyncio.get_running_loop()
encode_loop = run_aom_encode_loop
if numpy.any(vision_frame):
yield vision_frame
if stream_codec == 'vp8':
encode_loop = run_vp8_encode_loop
video_encode_task = event_loop.run_in_executor(None, encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval)
await websocket.send_text('ready')
async for frame_type, frame_buffer in stream_frames:
if frame_type == 1:
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
if numpy.any(vision_frame):
vision_frame_deque.append(vision_frame)
if frame_type == 2:
audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ])
while len(audio_temp) >= 1920:
audio_chunk = audio_temp[:1920]
audio_temp = audio_temp[1920:]
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk.tobytes(), 960)
if audio_buffer:
rtc_peers = rtc_store.get_rtc_peers(session_id)
if rtc_peers:
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
audio_timestamp += 960
vision_frame_deque.clear()
await video_encode_task
if opus_encoder:
destroy_opus_encoder(opus_encoder)
rtc_store.destroy_rtc_peers(session_id)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
websocket_event = await websocket.receive()
+240
View File
@@ -0,0 +1,240 @@
import asyncio
import queue
from typing import Any, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import cv2
import numpy
import pytest
from numpy.typing import NDArray
from starlette.websockets import WebSocketState
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
from facefusion.hash_helper import create_hash
from facefusion.types import VisionFrame
def _make_handler_websocket(events : list[Any]) -> MagicMock:
mock = MagicMock()
mock.scope = {}
mock.client_state = WebSocketState.CONNECTED
mock.accept = AsyncMock()
mock.send_text = AsyncMock()
mock.close = AsyncMock()
mock.receive = AsyncMock(side_effect = events)
return mock
def _make_video_packet(frame : NDArray[Any]) -> bytes:
_, encoded = cv2.imencode('.jpg', frame)
return b'\x01' + encoded.tobytes()
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
return b'\x02' + samples.tobytes()
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_run_video_encode_loop(video_codec : str) -> None:
frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)
prefix = 'facefusion.apis.stream_helper.'
create_name = prefix + 'create_aom_encoder'
encode_name = prefix + 'encode_aom_buffer'
destroy_name = prefix + 'destroy_aom_encoder'
run_loop = run_aom_encode_loop
if video_codec == 'vp8':
create_name = prefix + 'create_vpx_encoder'
encode_name = prefix + 'encode_vpx_buffer'
destroy_name = prefix + 'destroy_vpx_encoder'
run_loop = run_vp8_encode_loop
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
vision_frame_queue.put(frame)
vision_frame_queue.put(None)
with patch(prefix + 'process_vision_frame', return_value = frame), \
patch(create_name, return_value = MagicMock()), \
patch(encode_name, return_value = b'encoded'), \
patch(destroy_name), \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_called_once()
vision_frame_queue = queue.Queue()
vision_frame_queue.put(frame)
vision_frame_queue.put(frame)
vision_frame_queue.put(frame)
vision_frame_queue.put(None)
with patch(prefix + 'process_vision_frame', return_value = frame), \
patch(create_name, return_value = MagicMock()), \
patch(encode_name, return_value = b'encoded'), \
patch(destroy_name), \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
assert mock_rtc.send_video_to_peers.call_count == 3
vision_frame_queue = queue.Queue()
vision_frame_queue.put(black_frame)
with patch(create_name, return_value = MagicMock()), \
patch(destroy_name) as mock_destroy, \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
mock_destroy.assert_called_once()
vision_frame_queue = queue.Queue()
vision_frame_queue.put(frame)
vision_frame_queue.put(None)
with patch(prefix + 'process_vision_frame', return_value = frame), \
patch(create_name, return_value = MagicMock()), \
patch(encode_name, return_value = b''), \
patch(destroy_name), \
patch(prefix + 'rtc_store'), \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
vision_frame_queue = queue.Queue()
vision_frame_queue.put(small_frame)
vision_frame_queue.put(None)
with patch(prefix + 'process_vision_frame', return_value = large_frame), \
patch(create_name, return_value = MagicMock()) as mock_create, \
patch(encode_name, return_value = b'encoded'), \
patch(destroy_name) as mock_destroy, \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
assert mock_create.call_count == 2
assert mock_destroy.call_count == 2
mock_rtc.send_video_to_peers.assert_called_once()
vision_frame_queue = queue.Queue()
vision_frame_queue.put(frame)
vision_frame_queue.put(None)
with patch(create_name, return_value = None), \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
# TODO: refine test
def test_run_opus_encode_loop() -> None:
audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
audio_chunk_queue.put(audio_chunk)
audio_chunk_queue.put(None)
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_called_once()
assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0
audio_chunk_queue = queue.Queue()
audio_chunk_queue.put(audio_chunk)
audio_chunk_queue.put(audio_chunk)
audio_chunk_queue.put(None)
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_opus_encode_loop(audio_chunk_queue, 'session-1')
assert mock_rtc.send_audio_to_peers.call_count == 2
assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0
assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960
audio_chunk_queue = queue.Queue()
audio_chunk_queue.put(audio_chunk)
audio_chunk_queue.put(None)
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store'), \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_not_called()
audio_chunk_queue = queue.Queue()
audio_chunk_queue.put(audio_chunk)
audio_chunk_queue.put(None)
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
patch('facefusion.apis.stream_helper.rtc_store'), \
patch('facefusion.apis.stream_helper.rtc'):
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_destroy.assert_called_once()
audio_chunk_queue = queue.Queue()
audio_chunk_queue.put(b'')
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_not_called()
mock_destroy.assert_called_once()
# TODO: refine test
def test_handle_video_stream() -> None:
frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
video_packet = _make_video_packet(frame)
audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))
websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ])
with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \
patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
asyncio.run(handle_video_stream(websocket))
websocket.accept.assert_called_once_with(subprotocol = 'proto')
websocket.send_text.assert_called_once_with('ready')
websocket.close.assert_called_once()
mock_rtc.create_rtc_peers.assert_called_once_with('session-1')
mock_rtc.destroy_rtc_peers.assert_called_once_with('session-1')
_, loop_session_id, loop_resolution = mock_loop.call_args[0]
assert loop_session_id == 'session-1'
assert loop_resolution == (64, 64)
websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ])
with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
asyncio.run(handle_video_stream(websocket))
websocket.accept.assert_called_once()
websocket.send_text.assert_not_called()
mock_rtc.create_rtc_peers.assert_not_called()
websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ])
with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \
patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \
patch('facefusion.apis.stream_helper.rtc_store'):
asyncio.run(handle_video_stream(websocket))
audio_queue = mock_audio_loop.call_args[0][0]
assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'