From 1562fe2feea0d917f91a66e41682b288d10e5b40 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 14 May 2026 22:41:54 +0200 Subject: [PATCH] kill the stream helper in tests --- facefusion/rtc.py | 61 ++++++++++++++++++++++++++++------------ tests/stream_helper.py | 44 ----------------------------- tests/test_api_stream.py | 39 +++++++++++-------------- 3 files changed, 60 insertions(+), 84 deletions(-) delete mode 100644 tests/stream_helper.py diff --git a/facefusion/rtc.py b/facefusion/rtc.py index f3102849..a8f0b1f6 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -145,6 +145,31 @@ def create_sdp(peer_connection : PeerConnection) -> Optional[SdpOffer]: return None +# TODO: move from testing suite helper to rtc.py - belongs here to complete the rtc flow +def create_sdp_offer() -> Optional[SdpOffer]: + datachannel_library = datachannel_module.create_static_library() + peer_connection = create_peer_connection(disable_auto_negotiation = True) + + datachannel_library.rtcAddTrack(peer_connection, build_media_description('video', 96, 'VP8/90000', 'recvonly', 0)) + datachannel_library.rtcAddTrack(peer_connection, build_media_description('audio', 111, 'opus/48000/2', 'recvonly', 1)) + datachannel_library.rtcSetLocalDescription(peer_connection, b'offer') + + buffer_size = 16384 + buffer_string = ctypes.create_string_buffer(buffer_size) + wait_limit = time.monotonic() + 5 + + while time.monotonic() < wait_limit: + if datachannel_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0: + sdp = buffer_string.value.decode() + datachannel_library.rtcDeletePeerConnection(peer_connection) + return sdp + + time.sleep(0.05) + + datachannel_library.rtcDeletePeerConnection(peer_connection) + return None + + @ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) def on_sdp_ready(peer_connection : int, sdp : Optional[bytes], sdp_type : int, user_pointer : Optional[int]) -> None: ctypes.cast(user_pointer, ctypes.py_object).value.set() @@ -178,24 +203,6 @@ def negotiate_sdp(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Opt return None -def send_video_to_peers(rtc_peers : List[RtcPeer], frame_buffer : bytes) -> None: - datachannel_library = datachannel_module.create_static_library() - - if rtc_peers: - timestamp = int(time.monotonic() * 90000) & 0xFFFFFFFF - send_buffer = ctypes.create_string_buffer(frame_buffer) - send_total = len(frame_buffer) - - for rtc_peer in rtc_peers: - video_track_id = rtc_peer.get('video_track') - - if video_track_id and datachannel_library.rtcIsOpen(video_track_id): - datachannel_library.rtcSetTrackRtpTimestamp(video_track_id, timestamp) - datachannel_library.rtcSendMessage(video_track_id, send_buffer, send_total) - - return None - - def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_pts : int) -> None: datachannel_library = datachannel_module.create_static_library() @@ -214,6 +221,24 @@ def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_p return None +def send_video_to_peers(rtc_peers : List[RtcPeer], frame_buffer : bytes) -> None: + datachannel_library = datachannel_module.create_static_library() + + if rtc_peers: + timestamp = int(time.monotonic() * 90000) & 0xFFFFFFFF + send_buffer = ctypes.create_string_buffer(frame_buffer) + send_total = len(frame_buffer) + + for rtc_peer in rtc_peers: + video_track_id = rtc_peer.get('video_track') + + if video_track_id and datachannel_library.rtcIsOpen(video_track_id): + datachannel_library.rtcSetTrackRtpTimestamp(video_track_id, timestamp) + datachannel_library.rtcSendMessage(video_track_id, send_buffer, send_total) + + return None + + def delete_peers(rtc_peers : List[RtcPeer]) -> None: datachannel_library = datachannel_module.create_static_library() diff --git a/tests/stream_helper.py b/tests/stream_helper.py deleted file mode 100644 index 1dbcf9d3..00000000 --- a/tests/stream_helper.py +++ /dev/null @@ -1,44 +0,0 @@ -import ctypes -import threading -import time -from typing import Optional - -from starlette.testclient import TestClient - -from facefusion import rtc -from facefusion.libraries import datachannel as datachannel_module -from facefusion.types import SdpOffer - - -# TODO: remove, use rtc.create_sdp with recvonly tracks instead -def create_sdp_offer() -> Optional[SdpOffer]: - datachannel_library = datachannel_module.create_static_library() - peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True) - - datachannel_library.rtcAddTrack(peer_connection, rtc.build_media_description('video', 96, 'VP8/90000', 'recvonly', 0)) - datachannel_library.rtcAddTrack(peer_connection, rtc.build_media_description('audio', 111, 'opus/48000/2', 'recvonly', 1)) - datachannel_library.rtcSetLocalDescription(peer_connection, b'offer') - - buffer_size = 16384 - buffer_string = ctypes.create_string_buffer(buffer_size) - wait_limit = time.monotonic() + 5 - - while time.monotonic() < wait_limit: - if datachannel_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0: - sdp = buffer_string.value.decode() - datachannel_library.rtcDeletePeerConnection(peer_connection) - return sdp - - time.sleep(0.05) - - datachannel_library.rtcDeletePeerConnection(peer_connection) - return None - - -# TODO: remove, inline into test_api_stream.py -def open_websocket_stream(test_client : TestClient, subprotocols : list[str], source_content : bytes, ready_event : threading.Event, stop_event : threading.Event) -> None: - with test_client.websocket_connect('/stream?mode=video', subprotocols = subprotocols) as websocket: - websocket.send_bytes(b'\x01' + source_content) - websocket.receive_text() - ready_event.set() - stop_event.wait(timeout = 15) diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index 6220bccf..1b18c55c 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -7,14 +7,13 @@ from unittest.mock import patch import pytest from starlette.testclient import TestClient -from facefusion import metadata, session_manager, state_manager +from facefusion import metadata, rtc, session_manager, state_manager from facefusion.apis import asset_store from facefusion.apis.core import create_api from facefusion.core import common_pre_check from facefusion.download import conditional_download from facefusion.hash_helper import create_hash from .assert_helper import get_test_example_file, get_test_examples_directory -from .stream_helper import create_sdp_offer, open_websocket_stream @pytest.fixture(scope = 'module', autouse = True) @@ -50,7 +49,7 @@ def create_event() -> threading.Event: return threading.Event() -@pytest.fixture(scope = 'function') +@pytest.mark.helper def set_event(session_id : str, frame_buffer : bytes, event : threading.Event) -> None: event.set() @@ -124,26 +123,22 @@ def test_stream_video(test_client : TestClient, create_event : threading.Event) }) with patch('facefusion.rtc_store.send_rtc_video', side_effect = partial(set_event, event = create_event)): - ready_event = threading.Event() - stop_event = threading.Event() - stream_thread = threading.Thread(target = open_websocket_stream, args = (test_client, [ 'access_token.' + access_token ], source_content, ready_event, stop_event)) - stream_thread.start() - ready_event.wait(timeout = 10) + with test_client.websocket_connect('/stream?mode=video', subprotocols = + [ + 'access_token.' + access_token + ]) as websocket: + websocket.send_bytes(chr(1).encode() + source_content) + websocket.receive_text() - assert ready_event.is_set() + sdp_offer = rtc.create_sdp_offer() + stream_response = test_client.post('/stream', content = sdp_offer, headers = + { + 'Authorization': 'Bearer ' + access_token, + 'Content-Type': 'application/sdp' + }) - sdp_offer = create_sdp_offer() - stream_response = test_client.post('/stream', content = sdp_offer, headers = - { - 'Authorization': 'Bearer ' + access_token, - 'Content-Type': 'application/sdp' - }) + assert stream_response.status_code == 201 - assert stream_response.status_code == 201 + create_event.wait(timeout = 10) - create_event.wait(timeout = 10) - - assert create_event.is_set() - - stop_event.set() - stream_thread.join(timeout = 10) + assert create_event.is_set()