Files
facefusion/tests/test_stream_helper.py
T
Harisreedhar c48c238f88 Combine encode loop methods (#1119)
* combine run_aom_encode_loop and run_vpx_encode_loop to encode_video_loop

* run_opus_encode_loop -> encode_audio_loop

* use else instead of continue

* rename to video_codec
2026-05-16 23:29:36 +05:30

239 lines
11 KiB
Python

import asyncio
import queue
from typing import Any, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import cv2
import numpy
import pytest
from numpy.typing import NDArray
from starlette.websockets import WebSocketState
from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream
from facefusion.hash_helper import create_hash
from facefusion.types import VideoCodec, VisionFrame
def _make_handler_websocket(events : list[Any]) -> MagicMock:
    """Build a MagicMock that stands in for a connected websocket.

    The mock exposes an empty ``scope``, a ``client_state`` of
    ``WebSocketState.CONNECTED`` and async stubs for ``accept``,
    ``send_text`` and ``close``. ``events`` becomes the ``side_effect``
    of the async ``receive`` stub, so each awaited ``receive()`` yields
    the next entry.
    """
    websocket = MagicMock()
    websocket.scope = {}
    websocket.client_state = WebSocketState.CONNECTED

    # These coroutine methods need no scripted results, only awaitable stubs.
    for method_name in ('accept', 'send_text', 'close'):
        setattr(websocket, method_name, AsyncMock())

    websocket.receive = AsyncMock(side_effect = events)
    return websocket
def _make_video_packet(frame : NDArray[Any]) -> bytes:
    """Return ``frame`` JPEG-encoded via ``cv2.imencode``, prefixed with byte 0x01.

    NOTE(review): 0x01 is presumably the video packet tag expected by the
    stream handler - confirm against ``handle_video_stream``.
    """
    # imencode returns (success_flag, buffer); only the buffer is used here.
    jpeg_buffer = cv2.imencode('.jpg', frame)[1]
    return b''.join([ b'\x01', jpeg_buffer.tobytes() ])
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
return b'\x02' + samples.tobytes()
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_encode_video_loop(video_codec : VideoCodec) -> None:
    """Drive ``encode_video_loop`` with the encoder functions and rtc modules patched.

    For 'vp8' the vpx encoder functions are patched, otherwise the aom
    ones. Scenarios, in order:

    1. one frame then ``None``: exactly one ``send_video_to_peers`` call.
    2. three frames then ``None``: three ``send_video_to_peers`` calls.
    3. a lone black frame with no ``None`` sentinel: nothing is sent and
       the encoder is destroyed once.
    4. the encode function returns ``b''``: nothing is sent.
    5. the processed frame is 128x128 while (64, 64) was requested: the
       encoder is created twice, destroyed twice, and one send happens.
    6. the create function returns ``None``: nothing is sent.
    """
    prefix = 'facefusion.apis.stream_helper.'
    gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
    black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)

    # Select the patch targets that belong to the codec under test.
    if video_codec == 'vp8':
        create_name = prefix + 'create_vpx_encoder'
        encode_name = prefix + 'encode_vpx_buffer'
        destroy_name = prefix + 'destroy_vpx_encoder'
    else:
        create_name = prefix + 'create_aom_encoder'
        encode_name = prefix + 'encode_aom_buffer'
        destroy_name = prefix + 'destroy_aom_encoder'

    # Scenario 1: a single frame followed by the None sentinel.
    frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
    for item in (gray_frame, None):
        frame_queue.put(item)

    with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
            patch(create_name, return_value = MagicMock()), \
            patch(encode_name, return_value = b'encoded'), \
            patch(destroy_name), \
            patch(prefix + 'rtc_store') as rtc_store_mock, \
            patch(prefix + 'rtc') as rtc_mock:
        rtc_store_mock.get_peers.return_value = [ MagicMock() ]
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_called_once()

    # Scenario 2: three frames followed by the None sentinel.
    frame_queue = queue.Queue()
    for item in (gray_frame, gray_frame, gray_frame, None):
        frame_queue.put(item)

    with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
            patch(create_name, return_value = MagicMock()), \
            patch(encode_name, return_value = b'encoded'), \
            patch(destroy_name), \
            patch(prefix + 'rtc_store') as rtc_store_mock, \
            patch(prefix + 'rtc') as rtc_mock:
        rtc_store_mock.get_peers.return_value = [ MagicMock() ]
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    assert rtc_mock.send_video_to_peers.call_count == 3

    # Scenario 3: only a black frame is queued, with no None sentinel.
    frame_queue = queue.Queue()
    frame_queue.put(black_frame)

    with patch(create_name, return_value = MagicMock()), \
            patch(destroy_name) as destroy_mock, \
            patch(prefix + 'rtc') as rtc_mock:
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()
    destroy_mock.assert_called_once()

    # Scenario 4: the encoder yields an empty buffer.
    frame_queue = queue.Queue()
    for item in (gray_frame, None):
        frame_queue.put(item)

    with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
            patch(create_name, return_value = MagicMock()), \
            patch(encode_name, return_value = b''), \
            patch(destroy_name), \
            patch(prefix + 'rtc_store'), \
            patch(prefix + 'rtc') as rtc_mock:
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()

    # Scenario 5: the processed frame size differs from the requested resolution.
    frame_queue = queue.Queue()
    for item in (small_frame, None):
        frame_queue.put(item)

    with patch(prefix + 'process_vision_frame', return_value = large_frame), \
            patch(create_name, return_value = MagicMock()) as create_mock, \
            patch(encode_name, return_value = b'encoded'), \
            patch(destroy_name) as destroy_mock, \
            patch(prefix + 'rtc_store') as rtc_store_mock, \
            patch(prefix + 'rtc') as rtc_mock:
        rtc_store_mock.get_peers.return_value = [ MagicMock() ]
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    assert create_mock.call_count == 2
    assert destroy_mock.call_count == 2
    rtc_mock.send_video_to_peers.assert_called_once()

    # Scenario 6: encoder creation yields None.
    frame_queue = queue.Queue()
    for item in (gray_frame, None):
        frame_queue.put(item)

    with patch(create_name, return_value = None), \
            patch(prefix + 'rtc') as rtc_mock:
        encode_video_loop(video_codec, frame_queue, 'session-1', (64, 64))

    rtc_mock.send_video_to_peers.assert_not_called()
# TODO: refine test
def test_encode_audio_loop() -> None:
    """Drive ``encode_audio_loop`` with the opus encoder functions and rtc modules patched.

    Scenarios, in order:

    1. one chunk then ``None``: one ``send_audio_to_peers`` call whose
       third positional argument is 0.
    2. two chunks then ``None``: two calls whose third positional
       arguments are 0 and 960.
    3. the encode function returns ``b''``: nothing is sent.
    4. the encode function returns ``None``: the encoder is destroyed once.
    5. a lone ``b''`` chunk with no ``None`` sentinel: nothing is sent and
       the encoder is destroyed once.
    """
    create_target = 'facefusion.apis.stream_helper.create_opus_encoder'
    encode_target = 'facefusion.apis.stream_helper.encode_opus_buffer'
    destroy_target = 'facefusion.apis.stream_helper.destroy_opus_encoder'
    rtc_store_target = 'facefusion.apis.stream_helper.rtc_store'
    rtc_target = 'facefusion.apis.stream_helper.rtc'
    silent_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()

    # Scenario 1: a single chunk followed by the None sentinel.
    chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
    for item in (silent_chunk, None):
        chunk_queue.put(item)

    with patch(create_target, return_value = MagicMock()), \
            patch(encode_target, return_value = b'encoded'), \
            patch(destroy_target), \
            patch(rtc_store_target) as rtc_store_mock, \
            patch(rtc_target) as rtc_mock:
        rtc_store_mock.get_peers.return_value = [ MagicMock() ]
        encode_audio_loop('opus', chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_called_once()
    assert rtc_mock.send_audio_to_peers.call_args.args[2] == 0

    # Scenario 2: two chunks followed by the None sentinel.
    chunk_queue = queue.Queue()
    for item in (silent_chunk, silent_chunk, None):
        chunk_queue.put(item)

    with patch(create_target, return_value = MagicMock()), \
            patch(encode_target, return_value = b'encoded'), \
            patch(destroy_target), \
            patch(rtc_store_target) as rtc_store_mock, \
            patch(rtc_target) as rtc_mock:
        rtc_store_mock.get_peers.return_value = [ MagicMock() ]
        encode_audio_loop('opus', chunk_queue, 'session-1')

    send_calls = rtc_mock.send_audio_to_peers.call_args_list
    assert rtc_mock.send_audio_to_peers.call_count == 2
    assert send_calls[0].args[2] == 0
    assert send_calls[1].args[2] == 960

    # Scenario 3: the encoder yields an empty buffer.
    chunk_queue = queue.Queue()
    for item in (silent_chunk, None):
        chunk_queue.put(item)

    with patch(create_target, return_value = MagicMock()), \
            patch(encode_target, return_value = b''), \
            patch(destroy_target), \
            patch(rtc_store_target), \
            patch(rtc_target) as rtc_mock:
        encode_audio_loop('opus', chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_not_called()

    # Scenario 4: the encoder yields None.
    chunk_queue = queue.Queue()
    for item in (silent_chunk, None):
        chunk_queue.put(item)

    with patch(create_target, return_value = MagicMock()), \
            patch(encode_target, return_value = None), \
            patch(destroy_target) as destroy_mock, \
            patch(rtc_store_target), \
            patch(rtc_target):
        encode_audio_loop('opus', chunk_queue, 'session-1')

    destroy_mock.assert_called_once()

    # Scenario 5: only an empty chunk is queued, with no None sentinel.
    chunk_queue = queue.Queue()
    chunk_queue.put(b'')

    with patch(create_target, return_value = MagicMock()), \
            patch(destroy_target) as destroy_mock, \
            patch(rtc_target) as rtc_mock:
        encode_audio_loop('opus', chunk_queue, 'session-1')

    rtc_mock.send_audio_to_peers.assert_not_called()
    destroy_mock.assert_called_once()
# TODO: refine test
def test_handle_video_stream() -> None:
    """Run ``handle_video_stream`` against a mocked websocket with its collaborators patched.

    Scenarios, in order:

    1. ``find_session_id`` returns 'session-1': the socket is accepted
       with subprotocol 'proto', 'ready' is sent, the socket is closed
       once, peers are initialised and deleted for 'session-1', and
       ``encode_video_loop`` receives 'session-1' and (64, 64) as its
       third and fourth positional arguments.
    2. ``find_session_id`` returns ``None``: the socket is accepted once,
       nothing is sent and no peers are initialised.
    3. a video packet then an audio packet: the first item in the queue
       handed to ``encode_audio_loop`` hashes to '6d72f0fc'.
    """
    protocol_target = 'facefusion.apis.stream_helper.get_sec_websocket_protocol'
    token_target = 'facefusion.apis.stream_helper.extract_access_token'
    find_session_target = 'facefusion.apis.stream_helper.session_manager.find_session_id'
    set_session_target = 'facefusion.apis.stream_helper.session_context.set_session_id'
    get_item_target = 'facefusion.apis.stream_helper.state_manager.get_item'
    video_loop_target = 'facefusion.apis.stream_helper.encode_video_loop'
    audio_loop_target = 'facefusion.apis.stream_helper.encode_audio_loop'
    rtc_store_target = 'facefusion.apis.stream_helper.rtc_store'

    gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
    video_packet = _make_video_packet(gray_frame)
    audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))

    # Scenario 1: a known session receives one video packet, then disconnects.
    websocket = _make_handler_websocket(
    [
        {'type': 'websocket.receive', 'bytes': video_packet},
        {'type': 'websocket.disconnect'}
    ])

    with patch(protocol_target, return_value = 'proto'), \
            patch(token_target, return_value = 'token'), \
            patch(find_session_target, return_value = 'session-1'), \
            patch(set_session_target), \
            patch(get_item_target, return_value = 30), \
            patch(video_loop_target) as video_loop_mock, \
            patch(audio_loop_target), \
            patch(rtc_store_target) as rtc_store_mock:
        asyncio.run(handle_video_stream(websocket))

    websocket.accept.assert_called_once_with(subprotocol = 'proto')
    websocket.send_text.assert_called_once_with('ready')
    websocket.close.assert_called_once()
    rtc_store_mock.init_peers.assert_called_once_with('session-1')
    rtc_store_mock.delete_peers.assert_called_once_with('session-1')

    # The unpacking also pins the call to exactly four positional arguments.
    _, _, loop_session_id, loop_resolution = video_loop_mock.call_args.args
    assert loop_session_id == 'session-1'
    assert loop_resolution == (64, 64)

    # Scenario 2: the session lookup yields None.
    websocket = _make_handler_websocket(
    [
        {'type': 'websocket.receive', 'bytes': video_packet},
        {'type': 'websocket.disconnect'}
    ])

    with patch(protocol_target, return_value = 'proto'), \
            patch(token_target, return_value = 'token'), \
            patch(find_session_target, return_value = None), \
            patch(set_session_target), \
            patch(rtc_store_target) as rtc_store_mock:
        asyncio.run(handle_video_stream(websocket))

    websocket.accept.assert_called_once()
    websocket.send_text.assert_not_called()
    rtc_store_mock.init_peers.assert_not_called()

    # Scenario 3: a video packet followed by an audio packet.
    websocket = _make_handler_websocket(
    [
        {'type': 'websocket.receive', 'bytes': video_packet},
        {'type': 'websocket.receive', 'bytes': audio_packet},
        {'type': 'websocket.disconnect'}
    ])

    with patch(protocol_target, return_value = 'proto'), \
            patch(token_target, return_value = 'token'), \
            patch(find_session_target, return_value = 'session-1'), \
            patch(set_session_target), \
            patch(get_item_target, return_value = 30), \
            patch(video_loop_target), \
            patch(audio_loop_target) as audio_loop_mock, \
            patch(rtc_store_target):
        asyncio.run(handle_video_stream(websocket))

    # The queue is the second positional argument given to encode_audio_loop.
    audio_queue = audio_loop_mock.call_args.args[1]
    assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'