From 0f5f75ba5137034c55befba3091afbc6e63d11db Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Mon, 1 Jun 2026 08:54:37 +0200 Subject: [PATCH] Cleanup/testing suite (#1136) * clean testing suite * clean testing suite part2 * clean testing suite part3 * add todos * extend testing suite and kill some mutants * fix hashes * fix lint * fix test * fix test --- facefusion/apis/stream_helper.py | 2 + ...set_helper.py => test_api_asset_helper.py} | 0 tests/test_api_session.py | 5 +- ...am_helper.py => test_api_stream_helper.py} | 192 ++++++++++-------- tests/test_face_analyser.py | 76 +------ tests/test_time_helper.py | 4 +- 6 files changed, 120 insertions(+), 159 deletions(-) rename tests/{test_asset_helper.py => test_api_asset_helper.py} (100%) rename tests/{test_stream_helper.py => test_api_stream_helper.py} (59%) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 23068782..acb4d1d2 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -18,6 +18,7 @@ from facefusion.libraries import datachannel as datachannel_module from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, BitRate, PeerConnection, Resolution, RtcPeer, RtcPeerAudio, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder +#TODO: remove source_paths guard, process_image should work independent of source_paths since processors decide if they need sources async def process_image(websocket : WebSocket) -> None: source_paths = state_manager.get_item('source_paths') @@ -112,6 +113,7 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr #TODO: method is too complex def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: # TODO: combine video and audio queue + # TODO: update test_receive_video_frames, test_receive_audio_frames with the same approach (deque) video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1) audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4) receiver_threads = [] diff --git a/tests/test_asset_helper.py b/tests/test_api_asset_helper.py similarity index 100% rename from tests/test_asset_helper.py rename to tests/test_api_asset_helper.py diff --git a/tests/test_api_session.py b/tests/test_api_session.py index 85e7178b..e533a1cc 100644 --- a/tests/test_api_session.py +++ b/tests/test_api_session.py @@ -111,6 +111,8 @@ def test_refresh_session(test_client : TestClient) -> None: assert refresh_session_response.status_code == 401 + access_token = create_session_body.get('access_token') + refresh_session_response = test_client.put('/session', json = { 'refresh_token': create_session_body.get('refresh_token') @@ -119,8 +121,7 @@ def test_refresh_session(test_client : TestClient) -> None: assert refresh_session_body.get('access_token') assert refresh_session_body.get('refresh_token') - assert not refresh_session_body.get('access_token') == create_session_body.get('access_token') - + assert session_manager.find_session_id(access_token) is None assert refresh_session_response.status_code == 200 refresh_session_response = test_client.put('/session', json = diff --git a/tests/test_stream_helper.py b/tests/test_api_stream_helper.py similarity index 59% rename from tests/test_stream_helper.py rename to tests/test_api_stream_helper.py index efa09a9b..7a1f1821 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_api_stream_helper.py @@ -7,19 +7,17 @@ from unittest.mock import AsyncMock, MagicMock, patch import cv2 import numpy import pytest -from starlette.websockets import WebSocketState -from tests.assert_helper import get_test_example_file, get_test_examples_directory from facefusion import rtc, rtc_store, state_manager -from facefusion.apis.endpoints.stream import websocket_stream -from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop, update_video_encoder_bitrate -from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder +from facefusion.apis.stream_helper import create_video_decoder, create_video_encoder, decode_video_frame, destroy_stream, destroy_video_decoder, destroy_video_encoder, encode_video_frame, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop, update_video_encoder_bitrate +from facefusion.codecs import aom_encoder, vpx_encoder from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download from facefusion.hash_helper import create_hash from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, opus as opus_module, vpx as vpx_module from facefusion.types import AudioFrame, RtcPeer, VideoCodec, VisionFrame from facefusion.vision import read_video_frame +from .assert_helper import get_test_example_file, get_test_examples_directory @pytest.fixture(scope = 'module', autouse = True) @@ -27,44 +25,50 @@ def before_all() -> None: state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) state_manager.init_item('processors', []) + aom_module.pre_check() + vpx_module.pre_check() + opus_module.pre_check() + datachannel_module.pre_check() + conditional_download(get_test_examples_directory(), [ 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4', 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg' ]) - aom_module.pre_check() - vpx_module.pre_check() - opus_module.pre_check() - datachannel_module.pre_check() - @pytest.fixture(scope = 'function', autouse = True) def before_each() -> None: rtc_store.clear() -# TODO: refine test @pytest.mark.anyio async def test_process_image() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes() websocket_mock = AsyncMock() - websocket_mock.receive.side_effect = [{'type': 'websocket.receive', 'bytes': frame_buffer}] + websocket_mock.receive.side_effect =\ + [ + { + 'type': 'websocket.receive', + 'bytes': frame_buffer + } + ] - state_manager.init_item('source_paths', [get_test_example_file('source.jpg')]) + #TODO: remove init_item once source_paths guard is removed from process_image + state_manager.init_item('source_paths', [ get_test_example_file('source.jpg') ]) await process_image(websocket_mock) websocket_mock.send_bytes.assert_called_once() - assert websocket_mock.send_bytes.call_args[0][0][:3] == b'\xff\xd8\xff' + assert websocket_mock.send_bytes.call_args[0][0][:3] == bytes([ 255, 216, 255 ]) + #TODO: remove this block once source_paths guard is removed from process_image state_manager.init_item('source_paths', None) await process_image(websocket_mock) websocket_mock.send_bytes.assert_called_once() -# TODO: refine test @pytest.mark.parametrize('video_codec, session_id', [ ('av1', 'test-process-video-av1'), ('vp8', 'test-process-video-vp8') ]) def test_process_video(video_codec : VideoCodec, session_id : str) -> None: peer_connection = rtc.create_peer_connection() @@ -93,14 +97,13 @@ def test_process_video(video_codec : VideoCodec, session_id : str) -> None: assert sender_bitrate.value == 0 assert receiver_bitrate.value == 0 - rtc.handle_remb(0, 6000000, ctypes.addressof(sender_bitrate)) - assert sender_bitrate.value == 6000 + rtc.handle_remb(0, 8000000, ctypes.addressof(sender_bitrate)) + assert sender_bitrate.value == 8000 rtc.handle_remb(0, 4000000, ctypes.addressof(receiver_bitrate)) assert receiver_bitrate.value == 4000 -# TODO: refine test @pytest.mark.anyio async def test_receive_vision_frames() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) @@ -114,7 +117,7 @@ async def test_receive_vision_frames() -> None: }, { 'type': 'websocket.receive', - 'bytes': b'invalid' + 'bytes': 'invalid'.encode() }, { 'type': 'websocket.receive', @@ -134,10 +137,8 @@ async def test_receive_vision_frames() -> None: assert frames[0].shape == vision_frame.shape -# TODO: refine test def test_run_peer_loop() -> None: source_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - peer_connection = rtc.create_peer_connection() video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', 'vp8', 96) @@ -161,29 +162,28 @@ def test_run_peer_loop() -> None: datachannel_library_mock = MagicMock() datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ] - with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \ - patch('facefusion.apis.stream_helper.decode_video_frame', return_value = source_frame), \ - patch('facefusion.apis.stream_helper.rtc.send_video') as mock_send_video: - thread = threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True) - thread.start() - thread.join(timeout = 5.0) + with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock): + with patch('facefusion.apis.stream_helper.decode_video_frame', return_value = source_frame): + with patch('facefusion.apis.stream_helper.rtc.send_video') as mock_send_video: + thread = threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True) + thread.start() + thread.join(timeout = 5.0) assert mock_send_video.called assert len(mock_send_video.call_args[0][1]) > 0 -# TODO: refine test def test_receive_video_frames() -> None: - vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) datachannel_library_mock = MagicMock() datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ] + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1) - with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \ - patch('facefusion.apis.stream_helper.decode_video_frame', return_value = vision_frame): - receiver_thread = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue), daemon = True) - receiver_thread.start() - receiver_thread.join(timeout = 2.0) + with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock): + with patch('facefusion.apis.stream_helper.decode_video_frame', return_value = vision_frame): + receiver_thread = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue), daemon = True) + receiver_thread.start() + receiver_thread.join(timeout = 2.0) if is_linux() or is_windows(): assert create_hash(video_queue.get_nowait().tobytes()) == 'a17439db' @@ -192,86 +192,97 @@ def test_receive_video_frames() -> None: assert create_hash(video_queue.get_nowait().tobytes()) == '38d00e2a' -# TODO: refine test def test_receive_audio_frames() -> None: - audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32) datachannel_library_mock = MagicMock() datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ] + audio_frame = numpy.zeros(960 * 2).astype(numpy.float32) audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4) - with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \ - patch('facefusion.apis.stream_helper.opus_decoder.decode', return_value = audio_frame.tobytes()): - receiver_thread = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue), daemon = True) - receiver_thread.start() - audio_frame = audio_queue.get(timeout = 2.0) - receiver_thread.join(timeout = 1.0) + with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock): + with patch('facefusion.apis.stream_helper.opus_decoder.decode', return_value = audio_frame.tobytes()): + receiver_thread = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue), daemon = True) + receiver_thread.start() + audio_frame = audio_queue.get(timeout = 2.0) + receiver_thread.join(timeout = 1.0) assert audio_frame.dtype == numpy.float32 assert audio_frame.size == 960 * 2 assert audio_queue.empty() -# TODO: refine test -@pytest.mark.parametrize('video_codec', ['av1', 'vp8']) -def test_decode_video_frame(video_codec: VideoCodec) -> None: +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) +def test_encode_and_decode_video_frame(video_codec : VideoCodec) -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + input_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + video_encoder = create_video_encoder(video_codec, (426, 226), 1000) + video_decoder = create_video_decoder(video_codec) + encode_buffer = encode_video_frame(video_codec, video_encoder, input_buffer, (426, 226), 0) + decode_buffer = decode_video_frame(video_codec, video_decoder, encode_buffer).tobytes() + + if is_linux() or is_windows(): + if video_codec == 'av1': + assert create_hash(decode_buffer) == 'c97d6d29' + + if video_codec == 'vp8': + assert create_hash(decode_buffer) == '99ef2c25' + + if is_macos(): + if video_codec == 'av1': + assert create_hash(decode_buffer) == 'eafd1fab' + + if video_codec == 'vp8': + assert create_hash(decode_buffer) == 'ff3ecb43' + + assert decode_video_frame(video_codec, video_decoder, bytes()) is None + + +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) +def test_create_and_destroy_video_decoder(video_codec : VideoCodec) -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - frame_resolution = (vision_frame.shape[1], vision_frame.shape[0]) input_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() if video_codec == 'av1': - encode_buffer = aom_encoder.encode(aom_encoder.create(frame_resolution, 1000, 1, 0), input_buffer, frame_resolution, 0) - decode_buffer = decode_video_frame(video_codec, aom_decoder.create(8), encode_buffer).tobytes() - - if is_linux() or is_windows(): - assert create_hash(decode_buffer) == '299b6ad6' - - if is_macos(): - assert create_hash(decode_buffer) == '9f463b13' - - assert decode_video_frame('av1', aom_decoder.create(8), bytes()) is None + video_encoder = aom_encoder.create((426, 226), 1000, 1, 0) + encode_buffer = aom_encoder.encode(video_encoder, input_buffer, (426, 226), 0) if video_codec == 'vp8': - encode_buffer = vpx_encoder.encode(vpx_encoder.create(frame_resolution, 1000, 1, 0), input_buffer, frame_resolution, 0) - decode_buffer = decode_video_frame(video_codec, vpx_decoder.create(8), encode_buffer).tobytes() + video_encoder = vpx_encoder.create((426, 226), 1000, 1, 0) + encode_buffer = vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 0) - if is_linux() or is_windows(): - assert create_hash(decode_buffer) == '99ef2c25' + video_decoder = create_video_decoder(video_codec) - if is_macos(): - assert create_hash(decode_buffer) == 'ff3ecb43' + assert numpy.any(decode_video_frame(video_codec, video_decoder, encode_buffer)) - assert decode_video_frame('vp8', vpx_decoder.create(8), bytes()) is None + destroy_video_decoder(video_codec, video_decoder) + + assert decode_video_frame(video_codec, video_decoder, encode_buffer) is None @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - frame_resolution = (vision_frame.shape[1], vision_frame.shape[0]) input_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - video_encoder = create_video_encoder(video_codec, frame_resolution, 4000) + video_encoder = create_video_encoder(video_codec, (426, 226), 4000) if video_codec == 'av1': - assert aom_encoder.encode(video_encoder, input_buffer, frame_resolution, 0) + assert aom_encoder.encode(video_encoder, input_buffer, (426, 226), 0) if video_codec == 'vp8': - assert vpx_encoder.encode(video_encoder, input_buffer, frame_resolution, 0) + assert vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 0) destroy_video_encoder(video_codec, video_encoder) if video_codec == 'av1': - assert not aom_encoder.encode(video_encoder, input_buffer, frame_resolution, 1) + assert aom_encoder.encode(video_encoder, input_buffer, (426, 226), 1) == bytes() if video_codec == 'vp8': - assert not vpx_encoder.encode(video_encoder, input_buffer, frame_resolution, 1) + assert vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 1) == bytes() @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: - vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - frame_resolution = (vision_frame.shape[1], vision_frame.shape[0]) - - video_encoder = create_video_encoder(video_codec, frame_resolution, 4000) + video_encoder = create_video_encoder(video_codec, (426, 226), 4000) if video_codec == 'av1': assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 4000 @@ -290,24 +301,27 @@ def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: destroy_video_encoder(video_codec, video_encoder) -# TODO: refine test -@pytest.mark.anyio -async def test_websocket_stream() -> None: - websocket_mock = AsyncMock() - websocket_mock.scope =\ +def test_destroy_stream() -> None: + peer_connection = rtc.create_peer_connection() + rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) + rtc_peer : RtcPeer =\ { - 'type': 'websocket', - 'headers': [] + 'peer_connection': peer_connection, + 'video': + { + 'sender_track': 0, + 'receiver_track': 0, + 'codec': 'vp8' + }, + 'sender_bitrate': ctypes.c_uint(0), + 'receiver_bitrate': ctypes.c_uint(0) } - websocket_mock.client_state = WebSocketState.CONNECTED - state_manager.init_item('source_paths', None) + session_id = 'test-destroy-stream' + rtc_store.init_peers(session_id) + rtc_store.get_peers(session_id).append(rtc_peer) - with patch('facefusion.apis.endpoints.stream.get_sec_websocket_protocol', return_value = None), \ - patch('facefusion.apis.endpoints.stream.extract_access_token', return_value = None), \ - patch('facefusion.apis.endpoints.stream.session_manager.find_session_id', return_value = None), \ - patch('facefusion.apis.endpoints.stream.session_context.set_session_id'): - await websocket_stream(websocket_mock) + assert destroy_stream(session_id) is True + assert rtc_store.get_peers(session_id) is None - websocket_mock.accept.assert_called_once() - websocket_mock.close.assert_called_once() + assert destroy_stream(session_id) is False diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py index af435224..52e395d2 100644 --- a/tests/test_face_analyser.py +++ b/tests/test_face_analyser.py @@ -42,72 +42,16 @@ def before_each() -> None: face_recognizer.clear_inference_pool() -def test_get_one_face_with_retinaface() -> None: - state_manager.init_item('face_detector_model', 'retinaface') - state_manager.init_item('face_detector_size', '320x320') - state_manager.init_item('face_detector_margin', (0, 0, 0, 0)) - face_detector.pre_check() - - source_paths =\ - [ - get_test_example_file('source.jpg'), - get_test_example_file('source-80crop.jpg'), - get_test_example_file('source-70crop.jpg'), - get_test_example_file('source-60crop.jpg') - ] - - for source_path in source_paths: - source_frame = read_static_image(source_path) - many_faces = get_many_faces([ source_frame ]) - - assert len(many_faces) == 1 - - -def test_get_one_face_with_scrfd() -> None: - state_manager.init_item('face_detector_model', 'scrfd') - state_manager.init_item('face_detector_size', '320x320') - state_manager.init_item('face_detector_margin', (0, 0, 0, 0)) - face_detector.pre_check() - - source_paths =\ - [ - get_test_example_file('source.jpg'), - get_test_example_file('source-80crop.jpg'), - get_test_example_file('source-70crop.jpg'), - get_test_example_file('source-60crop.jpg') - ] - - for source_path in source_paths: - source_frame = read_static_image(source_path) - many_faces = get_many_faces([ source_frame ]) - - assert len(many_faces) == 1 - - -def test_get_one_face_with_yoloface() -> None: - state_manager.init_item('face_detector_model', 'yolo_face') - state_manager.init_item('face_detector_size', '640x640') - state_manager.init_item('face_detector_margin', (0, 0, 0, 0)) - face_detector.pre_check() - - source_paths =\ - [ - get_test_example_file('source.jpg'), - get_test_example_file('source-80crop.jpg'), - get_test_example_file('source-70crop.jpg'), - get_test_example_file('source-60crop.jpg') - ] - - for source_path in source_paths: - source_frame = read_static_image(source_path) - many_faces = get_many_faces([ source_frame ]) - - assert len(many_faces) == 1 - - -def test_get_one_face_with_yunet() -> None: - state_manager.init_item('face_detector_model', 'yunet') - state_manager.init_item('face_detector_size', '640x640') +@pytest.mark.parametrize('face_detector_model, face_detector_size', +[ + ('retinaface', '320x320'), + ('scrfd', '320x320'), + ('yolo_face', '640x640'), + ('yunet', '640x640') +]) +def test_get_one_face(face_detector_model : str, face_detector_size : str) -> None: + state_manager.init_item('face_detector_model', face_detector_model) + state_manager.init_item('face_detector_size', face_detector_size) state_manager.init_item('face_detector_margin', (0, 0, 0, 0)) face_detector.pre_check() diff --git a/tests/test_time_helper.py b/tests/test_time_helper.py index 987817b6..8d8021bd 100644 --- a/tests/test_time_helper.py +++ b/tests/test_time_helper.py @@ -4,8 +4,8 @@ from facefusion.time_helper import describe_time_ago def get_time_ago(days : int, hours : int, minutes : int) -> datetime: - previous_time = datetime.now() - timedelta(days = days, hours = hours, minutes = minutes) - return previous_time.astimezone() + time_ago = datetime.now() - timedelta(days = days, hours = hours, minutes = minutes) + return time_ago.astimezone() def test_describe_time_ago() -> None: