diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index 375a07b0..6220bccf 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -1,6 +1,8 @@ import tempfile import threading +from functools import partial from typing import Iterator +from unittest.mock import patch import pytest from starlette.testclient import TestClient @@ -43,6 +45,16 @@ def test_client() -> Iterator[TestClient]: yield test_client +@pytest.fixture(scope = 'function') +def create_event() -> threading.Event: + return threading.Event() + + +@pytest.fixture(scope = 'function') +def set_event(session_id : str, frame_buffer : bytes, event : threading.Event) -> None: + event.set() + + def test_stream_image(test_client : TestClient) -> None: create_session_response = test_client.post('/session', json = { @@ -83,8 +95,7 @@ def test_stream_image(test_client : TestClient) -> None: assert create_hash(output_buffer) == '0142782f' -#TODO: this test only checks the handshake and sdp offer but no stream of video bytes -def test_stream_video(test_client : TestClient) -> None: +def test_stream_video(test_client : TestClient, create_event : threading.Event) -> None: create_session_response = test_client.post('/session', json = { 'client_version': metadata.get('version') @@ -112,23 +123,27 @@ def test_stream_video(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + access_token }) - ready_event = threading.Event() - stop_event = threading.Event() - stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token ], source_content, ready_event, stop_event)) - stream_thread.start() - ready_event.wait(timeout = 10) + with patch('facefusion.rtc_store.send_rtc_video', side_effect = partial(set_event, event = create_event)): + ready_event = threading.Event() + stop_event = threading.Event() + stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token ], source_content, ready_event, stop_event)) + stream_thread.start() + ready_event.wait(timeout = 10) - assert ready_event.is_set() + assert ready_event.is_set() - sdp_offer = create_sdp_offer() - stream_response = test_client.post('/stream', content = sdp_offer, headers = - { - 'Authorization': 'Bearer ' + access_token, - 'Content-Type': 'application/sdp' - }) + sdp_offer = create_sdp_offer() + stream_response = test_client.post('/stream', content = sdp_offer, headers = + { + 'Authorization': 'Bearer ' + access_token, + 'Content-Type': 'application/sdp' + }) - assert stream_response.status_code == 201 - assert stream_response.text + assert stream_response.status_code == 201 - stop_event.set() - stream_thread.join(timeout = 10) + create_event.wait(timeout = 10) + + assert create_event.is_set() + + stop_event.set() + stream_thread.join(timeout = 10)