Files
facefusion/tests/test_stream_helper.py
T
Harisreedhar 0019d3ad0f Refactor stream_helper: queue-based audio/video loops with unified threading (#1116)
* rearrange methods following the flow

* add test_stream_helper.py

* fix lint

* fix lint

* refactor audio flow to match video by replacing dequeue with queue

* remove unused keyframe interval

* remove try block

* remove while True

* simplify run_aom_encode_loop and run_vp8_encode_loop

* cleanup names

* simplify run_opus_encode_loop

* move opus_encoder creation to run_opus_encode_loop

* add todos

* fix lint

* update todos and tests
2026-05-15 21:47:30 +05:30

241 lines
11 KiB
Python

import asyncio
import queue
from typing import Any, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import cv2
import numpy
import pytest
from numpy.typing import NDArray
from starlette.websockets import WebSocketState
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
from facefusion.hash_helper import create_hash
from facefusion.types import VisionFrame
def _make_handler_websocket(events : list[Any]) -> MagicMock:
    """
    Build a MagicMock that stands in for a connected websocket.

    The mock reports WebSocketState.CONNECTED, carries an empty scope and
    exposes accept, send_text, close and receive as AsyncMocks. receive uses
    the given events as its side_effect, so successive awaits hand them out
    in order.
    """
    websocket = MagicMock()
    websocket.scope = {}
    websocket.client_state = WebSocketState.CONNECTED
    # the coroutine methods without a scripted result
    for method_name in ('accept', 'send_text', 'close'):
        setattr(websocket, method_name, AsyncMock())
    websocket.receive = AsyncMock(side_effect = events)
    return websocket
def _make_video_packet(frame : NDArray[Any]) -> bytes:
    """
    Encode the frame as JPEG and return it prefixed with a single 0x01 byte.

    NOTE(review): the 0x01 prefix is presumably the video packet tag read by
    handle_video_stream - confirm against stream_helper.

    Raises ValueError when cv2.imencode reports that encoding failed, so a
    broken fixture is not silently turned into a malformed packet.
    """
    is_encoded, jpeg_buffer = cv2.imencode('.jpg', frame)
    # imencode signals failure through its first return value
    if not is_encoded:
        raise ValueError('cv2.imencode could not encode the frame as jpg')
    return b'\x01' + jpeg_buffer.tobytes()
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
return b'\x02' + samples.tobytes()
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_run_video_encode_loop(video_codec : str) -> None:
    """
    Feed queues of vision frames to the av1 / vp8 encode loop and check what
    reaches rtc.send_video_to_peers.

    The encoder functions, process_vision_frame, rtc_store and rtc are patched
    per scenario inside facefusion.apis.stream_helper.
    """
    module_path = 'facefusion.apis.stream_helper.'

    # anything other than 'vp8' uses the aom (av1) encoder functions
    if video_codec == 'vp8':
        run_loop = run_vp8_encode_loop
        encoder_names = ('create_vpx_encoder', 'encode_vpx_buffer', 'destroy_vpx_encoder')
    else:
        run_loop = run_aom_encode_loop
        encoder_names = ('create_aom_encoder', 'encode_aom_buffer', 'destroy_aom_encoder')
    create_target, encode_target, destroy_target = [ module_path + encoder_name for encoder_name in encoder_names ]

    gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    queued_small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    processed_large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
    black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)

    def queue_frames(*vision_frames : Optional[VisionFrame]) -> queue.Queue[Optional[VisionFrame]]:
        # fresh queue holding the given items in order
        frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
        for vision_frame in vision_frames:
            frame_queue.put(vision_frame)
        return frame_queue

    # one frame followed by the None sentinel: exactly one send
    frame_queue = queue_frames(gray_frame, None)
    with patch(module_path + 'process_vision_frame', return_value = gray_frame), \
        patch(create_target, return_value = MagicMock()), \
        patch(encode_target, return_value = b'encoded'), \
        patch(destroy_target), \
        patch(module_path + 'rtc_store') as rtc_store_mock, \
        patch(module_path + 'rtc') as rtc_mock:
        rtc_store_mock.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_called_once()

    # three frames followed by the None sentinel: one send per frame
    frame_queue = queue_frames(gray_frame, gray_frame, gray_frame, None)
    with patch(module_path + 'process_vision_frame', return_value = gray_frame), \
        patch(create_target, return_value = MagicMock()), \
        patch(encode_target, return_value = b'encoded'), \
        patch(destroy_target), \
        patch(module_path + 'rtc_store') as rtc_store_mock, \
        patch(module_path + 'rtc') as rtc_mock:
        rtc_store_mock.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(frame_queue, 'session-1', (64, 64))

    assert rtc_mock.send_video_to_peers.call_count == 3

    # only a black frame is queued, without a None sentinel: nothing is sent
    # and the encoder is destroyed once
    # NOTE(review): presumably an all-zero frame ends the loop - confirm against stream_helper
    frame_queue = queue_frames(black_frame)
    with patch(create_target, return_value = MagicMock()), \
        patch(destroy_target) as destroy_mock, \
        patch(module_path + 'rtc') as rtc_mock:
        run_loop(frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()
    destroy_mock.assert_called_once()

    # the encoder yields an empty buffer: nothing is sent
    frame_queue = queue_frames(gray_frame, None)
    with patch(module_path + 'process_vision_frame', return_value = gray_frame), \
        patch(create_target, return_value = MagicMock()), \
        patch(encode_target, return_value = b''), \
        patch(destroy_target), \
        patch(module_path + 'rtc_store'), \
        patch(module_path + 'rtc') as rtc_mock:
        run_loop(frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()

    # the processed frame is 128x128 while the loop was started with (64, 64):
    # the encoder is created and destroyed twice and the frame is still sent once
    frame_queue = queue_frames(queued_small_frame, None)
    with patch(module_path + 'process_vision_frame', return_value = processed_large_frame), \
        patch(create_target, return_value = MagicMock()) as create_mock, \
        patch(encode_target, return_value = b'encoded'), \
        patch(destroy_target) as destroy_mock, \
        patch(module_path + 'rtc_store') as rtc_store_mock, \
        patch(module_path + 'rtc') as rtc_mock:
        rtc_store_mock.get_rtc_peers.return_value = [ MagicMock() ]
        run_loop(frame_queue, 'session-1', (64, 64))

    assert create_mock.call_count == 2
    assert destroy_mock.call_count == 2
    rtc_mock.send_video_to_peers.assert_called_once()

    # encoder creation returns None: nothing is sent
    frame_queue = queue_frames(gray_frame, None)
    with patch(create_target, return_value = None), \
        patch(module_path + 'rtc') as rtc_mock:
        run_loop(frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()
# TODO: refine test
def test_run_opus_encode_loop() -> None:
    """
    Feed queues of audio chunks to the opus encode loop and check what reaches
    rtc.send_audio_to_peers.

    The opus encoder functions, rtc_store and rtc are patched per scenario
    inside facefusion.apis.stream_helper.
    """
    pcm_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()

    def queue_chunks(*audio_chunks : Optional[bytes]) -> queue.Queue[Optional[bytes]]:
        # fresh queue holding the given items in order
        chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
        for audio_chunk in audio_chunks:
            chunk_queue.put(audio_chunk)
        return chunk_queue

    # one chunk followed by the None sentinel: one send whose third positional argument is 0
    chunk_queue = queue_chunks(pcm_chunk, None)
    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
        patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
        patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
        patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock, \
        patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
        rtc_store_mock.get_rtc_peers.return_value = [ MagicMock() ]
        run_opus_encode_loop(chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_called_once()
    assert rtc_mock.send_audio_to_peers.call_args.args[2] == 0

    # two chunks: two sends, the third positional argument goes from 0 to 960
    # NOTE(review): presumably a running sample timestamp - confirm against stream_helper
    chunk_queue = queue_chunks(pcm_chunk, pcm_chunk, None)
    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
        patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
        patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
        patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock, \
        patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
        rtc_store_mock.get_rtc_peers.return_value = [ MagicMock() ]
        run_opus_encode_loop(chunk_queue, 'session-1')

    send_calls = rtc_mock.send_audio_to_peers.call_args_list
    assert rtc_mock.send_audio_to_peers.call_count == 2
    assert send_calls[0].args[2] == 0
    assert send_calls[1].args[2] == 960

    # the encoder yields an empty buffer: nothing is sent
    chunk_queue = queue_chunks(pcm_chunk, None)
    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
        patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \
        patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
        patch('facefusion.apis.stream_helper.rtc_store'), \
        patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
        run_opus_encode_loop(chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_not_called()

    # the encoder yields None: the encoder is still destroyed exactly once
    chunk_queue = queue_chunks(pcm_chunk, None)
    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
        patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \
        patch('facefusion.apis.stream_helper.destroy_opus_encoder') as destroy_mock, \
        patch('facefusion.apis.stream_helper.rtc_store'), \
        patch('facefusion.apis.stream_helper.rtc'):
        run_opus_encode_loop(chunk_queue, 'session-1')

    destroy_mock.assert_called_once()

    # only an empty chunk is queued, without a None sentinel: nothing is sent
    # and the encoder is destroyed once
    # NOTE(review): presumably an empty chunk ends the loop - confirm against stream_helper
    chunk_queue = queue_chunks(b'')
    with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
        patch('facefusion.apis.stream_helper.destroy_opus_encoder') as destroy_mock, \
        patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
        run_opus_encode_loop(chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_not_called()
    destroy_mock.assert_called_once()
# TODO: refine test
def test_handle_video_stream() -> None:
    """
    Run handle_video_stream against a mocked websocket and check the handshake,
    the rtc peer bookkeeping and the arguments handed to the encode loops.

    Session lookup, state lookup, both encode loops and rtc_store are patched
    inside facefusion.apis.stream_helper.
    """
    video_packet = _make_video_packet(numpy.full((64, 64, 3), 128, dtype = numpy.uint8))
    disconnect_event = {'type': 'websocket.disconnect'}

    # known session, one video packet then disconnect
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, disconnect_event ])
    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
        patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
        patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
        patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
        patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
        patch('facefusion.apis.stream_helper.run_aom_encode_loop') as video_loop_mock, \
        patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \
        patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock:
        asyncio.run(handle_video_stream(websocket))

    websocket.accept.assert_called_once_with(subprotocol = 'proto')
    websocket.send_text.assert_called_once_with('ready')
    websocket.close.assert_called_once()
    rtc_store_mock.create_rtc_peers.assert_called_once_with('session-1')
    rtc_store_mock.destroy_rtc_peers.assert_called_once_with('session-1')
    # the video loop gets (queue, session id, resolution); the resolution matches the 64x64 frame
    _, loop_session_id, loop_resolution = video_loop_mock.call_args.args
    assert loop_session_id == 'session-1'
    assert loop_resolution == (64, 64)

    # unknown session: the socket is accepted, but no 'ready' is sent and no peers are created
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, disconnect_event ])
    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
        patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
        patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \
        patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
        patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock:
        asyncio.run(handle_video_stream(websocket))

    websocket.accept.assert_called_once()
    websocket.send_text.assert_not_called()
    rtc_store_mock.create_rtc_peers.assert_not_called()

    # a video packet followed by an audio packet: the first item on the audio loop's queue is hashed
    audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))
    websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, disconnect_event ])
    with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
        patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
        patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
        patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
        patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
        patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \
        patch('facefusion.apis.stream_helper.run_opus_encode_loop') as audio_loop_mock, \
        patch('facefusion.apis.stream_helper.rtc_store'):
        asyncio.run(handle_video_stream(websocket))

    audio_chunk_queue = audio_loop_mock.call_args.args[0]
    assert create_hash(audio_chunk_queue.get_nowait()) == '6d72f0fc'