mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 19:01:35 +02:00
0019d3ad0f
* rearrange methods following the flow * add test_stream_helper.py * fix lint * fix lint * refactor audio flow to match video by replacing dequeue with queue * remove unused keyframe interval * remove try block * remove while True * simplify run_aom_encode_loop and run_vp8_encode_loop * cleanup names * simplify run_opus_encode_loop * move opus_encoder creation to run_opus_encode_loop * add todos * fix lint * update todos and tests
241 lines
11 KiB
Python
241 lines
11 KiB
Python
import asyncio
|
|
import queue
|
|
from typing import Any, Optional
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import cv2
|
|
import numpy
|
|
import pytest
|
|
from numpy.typing import NDArray
|
|
from starlette.websockets import WebSocketState
|
|
|
|
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
|
|
from facefusion.hash_helper import create_hash
|
|
from facefusion.types import VisionFrame
|
|
|
|
|
|
def _make_handler_websocket(events : list[Any]) -> MagicMock:
	"""
	Build a stand-in websocket for the stream handler tests.

	The returned MagicMock carries an empty scope, reports the CONNECTED
	client state and exposes accept, send_text and close as AsyncMock
	instances. Its receive attribute is an AsyncMock whose side_effect is
	the given events.
	"""
	websocket = MagicMock()
	websocket.scope = {}
	websocket.client_state = WebSocketState.CONNECTED

	# the handler awaits these, so plain MagicMock attributes are not enough
	for coroutine_name in ('accept', 'send_text', 'close'):
		setattr(websocket, coroutine_name, AsyncMock())

	websocket.receive = AsyncMock(side_effect = events)
	return websocket
|
|
|
|
|
|
def _make_video_packet(frame : NDArray[Any]) -> bytes:
	"""
	Encode the frame as JPEG and return the bytes prefixed with the single byte 0x01.

	NOTE(review): 0x01 is presumably the packet type tag the stream handler
	uses for video, confirm against stream_helper.
	"""
	# imencode returns (success flag, buffer), only the buffer is used here
	jpeg_buffer = cv2.imencode('.jpg', frame)[1]
	return b'\x01' + jpeg_buffer.tobytes()
|
|
|
|
|
|
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
	"""
	Return the raw bytes of the samples prefixed with the single byte 0x02.

	NOTE(review): 0x02 is presumably the packet type tag the stream handler
	uses for audio, confirm against stream_helper.
	"""
	payload = samples.tobytes()
	return b'\x02' + payload
|
|
|
|
|
|
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_run_video_encode_loop(video_codec : str) -> None:
	"""
	Drive the codec specific encode loop with pre-filled queues and check how often video is sent to the peers.

	Depending on the scenario, the encoder factory, the encode call, the
	encoder teardown, process_vision_frame, rtc and rtc_store are patched
	inside facefusion.apis.stream_helper.
	"""
	prefix = 'facefusion.apis.stream_helper.'
	session_id = 'session-1'
	resolution = (64, 64)
	gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
	black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)

	if video_codec == 'vp8':
		create_name = prefix + 'create_vpx_encoder'
		encode_name = prefix + 'encode_vpx_buffer'
		destroy_name = prefix + 'destroy_vpx_encoder'
		run_loop = run_vp8_encode_loop
	else:
		create_name = prefix + 'create_aom_encoder'
		encode_name = prefix + 'encode_aom_buffer'
		destroy_name = prefix + 'destroy_aom_encoder'
		run_loop = run_aom_encode_loop

	def fill_queue(*items : Optional[VisionFrame]) -> queue.Queue[Optional[VisionFrame]]:
		# create a fresh queue holding the given items in order
		filled_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()

		for item in items:
			filled_queue.put(item)
		return filled_queue

	# one frame followed by the None sentinel: exactly one send
	vision_frame_queue = fill_queue(gray_frame, None)
	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store') as mock_rtc_store, \
		patch(prefix + 'rtc') as mock_rtc:
		mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
		run_loop(vision_frame_queue, session_id, resolution)
		mock_rtc.send_video_to_peers.assert_called_once()

	# three frames followed by the None sentinel: one send per frame
	vision_frame_queue = fill_queue(gray_frame, gray_frame, gray_frame, None)
	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store') as mock_rtc_store, \
		patch(prefix + 'rtc') as mock_rtc:
		mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
		run_loop(vision_frame_queue, session_id, resolution)
		assert mock_rtc.send_video_to_peers.call_count == 3

	# a lone black frame without sentinel: nothing is sent and the encoder is destroyed once
	# NOTE(review): presumably the loop treats an all-zero frame as its stop signal, confirm in stream_helper
	vision_frame_queue = fill_queue(black_frame)
	with patch(create_name, return_value = MagicMock()), \
		patch(destroy_name) as mock_destroy, \
		patch(prefix + 'rtc') as mock_rtc:
		run_loop(vision_frame_queue, session_id, resolution)
		mock_rtc.send_video_to_peers.assert_not_called()
		mock_destroy.assert_called_once()

	# the encoder yields an empty buffer: nothing is sent
	vision_frame_queue = fill_queue(gray_frame, None)
	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b''), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store'), \
		patch(prefix + 'rtc') as mock_rtc:
		run_loop(vision_frame_queue, session_id, resolution)
		mock_rtc.send_video_to_peers.assert_not_called()

	# the processed frame is larger than the initial resolution: the encoder is created and destroyed twice
	vision_frame_queue = fill_queue(small_frame, None)
	with patch(prefix + 'process_vision_frame', return_value = large_frame), \
		patch(create_name, return_value = MagicMock()) as mock_create, \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name) as mock_destroy, \
		patch(prefix + 'rtc_store') as mock_rtc_store, \
		patch(prefix + 'rtc') as mock_rtc:
		mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
		run_loop(vision_frame_queue, session_id, resolution)
		assert mock_create.call_count == 2
		assert mock_destroy.call_count == 2
		mock_rtc.send_video_to_peers.assert_called_once()

	# the encoder factory returns None: nothing is sent
	vision_frame_queue = fill_queue(gray_frame, None)
	with patch(create_name, return_value = None), \
		patch(prefix + 'rtc') as mock_rtc:
		run_loop(vision_frame_queue, session_id, resolution)
		mock_rtc.send_video_to_peers.assert_not_called()
|
|
|
|
|
|
# TODO: refine test
|
|
def test_run_opus_encode_loop() -> None:
	"""
	Drive run_opus_encode_loop with pre-filled queues and check what is sent to the peers.

	Depending on the scenario, the opus encoder factory, the encode call,
	the encoder teardown, rtc and rtc_store are patched inside
	facefusion.apis.stream_helper.
	"""
	prefix = 'facefusion.apis.stream_helper.'
	session_id = 'session-1'
	audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()

	def fill_queue(*items : Optional[bytes]) -> queue.Queue[Optional[bytes]]:
		# create a fresh queue holding the given items in order
		filled_queue : queue.Queue[Optional[bytes]] = queue.Queue()

		for item in items:
			filled_queue.put(item)
		return filled_queue

	# one chunk followed by the None sentinel: a single send whose third positional argument is 0
	audio_chunk_queue = fill_queue(audio_chunk, None)
	with patch(prefix + 'create_opus_encoder', return_value = MagicMock()), \
		patch(prefix + 'encode_opus_buffer', return_value = b'encoded'), \
		patch(prefix + 'destroy_opus_encoder'), \
		patch(prefix + 'rtc_store') as mock_rtc_store, \
		patch(prefix + 'rtc') as mock_rtc:
		mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
		run_opus_encode_loop(audio_chunk_queue, session_id)
		mock_rtc.send_audio_to_peers.assert_called_once()
		assert mock_rtc.send_audio_to_peers.call_args.args[2] == 0

	# two chunks: the third positional argument goes from 0 to 960
	audio_chunk_queue = fill_queue(audio_chunk, audio_chunk, None)
	with patch(prefix + 'create_opus_encoder', return_value = MagicMock()), \
		patch(prefix + 'encode_opus_buffer', return_value = b'encoded'), \
		patch(prefix + 'destroy_opus_encoder'), \
		patch(prefix + 'rtc_store') as mock_rtc_store, \
		patch(prefix + 'rtc') as mock_rtc:
		mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
		run_opus_encode_loop(audio_chunk_queue, session_id)
		assert mock_rtc.send_audio_to_peers.call_count == 2
		first_call, second_call = mock_rtc.send_audio_to_peers.call_args_list
		assert first_call.args[2] == 0
		assert second_call.args[2] == 960

	# the encoder yields an empty buffer: nothing is sent
	audio_chunk_queue = fill_queue(audio_chunk, None)
	with patch(prefix + 'create_opus_encoder', return_value = MagicMock()), \
		patch(prefix + 'encode_opus_buffer', return_value = b''), \
		patch(prefix + 'destroy_opus_encoder'), \
		patch(prefix + 'rtc_store'), \
		patch(prefix + 'rtc') as mock_rtc:
		run_opus_encode_loop(audio_chunk_queue, session_id)
		mock_rtc.send_audio_to_peers.assert_not_called()

	# the encoder yields None: the encoder is still destroyed exactly once
	audio_chunk_queue = fill_queue(audio_chunk, None)
	with patch(prefix + 'create_opus_encoder', return_value = MagicMock()), \
		patch(prefix + 'encode_opus_buffer', return_value = None), \
		patch(prefix + 'destroy_opus_encoder') as mock_destroy, \
		patch(prefix + 'rtc_store'), \
		patch(prefix + 'rtc'):
		run_opus_encode_loop(audio_chunk_queue, session_id)
		mock_destroy.assert_called_once()

	# a lone empty chunk without sentinel: nothing is sent and the encoder is destroyed once
	# NOTE(review): presumably the loop treats an empty chunk as its stop signal, confirm in stream_helper
	audio_chunk_queue = fill_queue(b'')
	with patch(prefix + 'create_opus_encoder', return_value = MagicMock()), \
		patch(prefix + 'destroy_opus_encoder') as mock_destroy, \
		patch(prefix + 'rtc') as mock_rtc:
		run_opus_encode_loop(audio_chunk_queue, session_id)
		mock_rtc.send_audio_to_peers.assert_not_called()
		mock_destroy.assert_called_once()
|
|
|
|
|
|
# TODO: refine test
|
|
def test_handle_video_stream() -> None:
	"""
	Run handle_video_stream against a mocked websocket and check the handshake, the peer lifecycle and the queue hand-off.

	Session lookup, state access, both encode loops and rtc_store are
	patched inside facefusion.apis.stream_helper, so the encode loops
	themselves never run.
	"""
	prefix = 'facefusion.apis.stream_helper.'
	session_id = 'session-1'
	frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	video_packet = _make_video_packet(frame)
	audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))

	def receive_events(*packets : bytes) -> list[Any]:
		# one receive event per packet, terminated by a disconnect event
		events : list[Any] = [ {'type': 'websocket.receive', 'bytes': packet} for packet in packets ]
		events.append({'type': 'websocket.disconnect'})
		return events

	# known session: handshake, ready message, peer lifecycle and the arguments of the video loop
	websocket = _make_handler_websocket(receive_events(video_packet))
	with patch(prefix + 'get_sec_websocket_protocol', return_value = 'proto'), \
		patch(prefix + 'extract_access_token', return_value = 'token'), \
		patch(prefix + 'session_manager.find_session_id', return_value = session_id), \
		patch(prefix + 'session_context.set_session_id'), \
		patch(prefix + 'state_manager.get_item', return_value = 30), \
		patch(prefix + 'run_aom_encode_loop') as mock_video_loop, \
		patch(prefix + 'run_opus_encode_loop'), \
		patch(prefix + 'rtc_store') as mock_rtc_store:
		asyncio.run(handle_video_stream(websocket))
		websocket.accept.assert_called_once_with(subprotocol = 'proto')
		websocket.send_text.assert_called_once_with('ready')
		websocket.close.assert_called_once()
		mock_rtc_store.create_rtc_peers.assert_called_once_with(session_id)
		mock_rtc_store.destroy_rtc_peers.assert_called_once_with(session_id)
		# the loop is expected to receive exactly three positional arguments
		_, loop_session_id, loop_resolution = mock_video_loop.call_args.args
		assert loop_session_id == session_id
		assert loop_resolution == (64, 64)

	# unknown session: the socket is accepted, but no ready message and no peers
	websocket = _make_handler_websocket(receive_events(video_packet))
	with patch(prefix + 'get_sec_websocket_protocol', return_value = 'proto'), \
		patch(prefix + 'extract_access_token', return_value = 'token'), \
		patch(prefix + 'session_manager.find_session_id', return_value = None), \
		patch(prefix + 'session_context.set_session_id'), \
		patch(prefix + 'rtc_store') as mock_rtc_store:
		asyncio.run(handle_video_stream(websocket))
		websocket.accept.assert_called_once()
		websocket.send_text.assert_not_called()
		mock_rtc_store.create_rtc_peers.assert_not_called()

	# video then audio: the first item on the queue handed to the opus loop must hash to the pinned value
	websocket = _make_handler_websocket(receive_events(video_packet, audio_packet))
	with patch(prefix + 'get_sec_websocket_protocol', return_value = 'proto'), \
		patch(prefix + 'extract_access_token', return_value = 'token'), \
		patch(prefix + 'session_manager.find_session_id', return_value = session_id), \
		patch(prefix + 'session_context.set_session_id'), \
		patch(prefix + 'state_manager.get_item', return_value = 30), \
		patch(prefix + 'run_aom_encode_loop'), \
		patch(prefix + 'run_opus_encode_loop') as mock_audio_loop, \
		patch(prefix + 'rtc_store'):
		asyncio.run(handle_video_stream(websocket))
		audio_chunk_queue = mock_audio_loop.call_args.args[0]
		assert create_hash(audio_chunk_queue.get_nowait()) == '6d72f0fc'
|