From dd1ded1408bb3f00b40e2c59b255bfed5c127e29 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Sat, 16 May 2026 09:06:04 +0200 Subject: [PATCH] Refactor/rtc cleanup 3 (#1118) * tweak rtc store and make the decision to ban trivial testing * clear todos for rtc_test, remove redundant tests * clear todos for rtc_test, remove redundant tests * break negotiation out of rtc flow, introduce create_sdp_answer and set_remote_description * add todo * move timeline control to the stream helper, clean send_audio|video_to_peers * rename some methods * fix test * introduce detect_sdp_media * introduce detect_sdp_media --- facefusion/apis/endpoints/stream.py | 4 +- facefusion/apis/stream_helper.py | 42 ++++++------- facefusion/libraries/datachannel.py | 18 ------ facefusion/rtc.py | 98 ++++++++++++++++------------- facefusion/rtc_store.py | 26 +++++--- facefusion/types.py | 19 +++++- tests/test_api_stream.py | 2 +- tests/test_rtc.py | 95 ++++++++++++---------------- tests/test_rtc_store.py | 66 ------------------- tests/test_stream_helper.py | 6 +- 10 files changed, 157 insertions(+), 219 deletions(-) delete mode 100644 tests/test_rtc_store.py diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 3e1b6137..fabdd159 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -5,7 +5,7 @@ from starlette.websockets import WebSocket from facefusion import session_context, session_manager from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import add_rtc_viewer, handle_image_stream, handle_video_stream +from facefusion.apis.stream_helper import connect_rtc, handle_image_stream, handle_video_stream async def websocket_stream(websocket : WebSocket) -> None: @@ -28,7 +28,7 @@ async def post_stream(request : Request) -> Response: if content_type == 'application/sdp' and session_id: sdp_offer = await request.body() - sdp_answer = add_rtc_viewer(session_id, sdp_offer.decode()) + sdp_answer = connect_rtc(session_id, sdp_offer.decode()) if sdp_answer: return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp') diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 757b9737..538046ab 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,5 +1,6 @@ import asyncio import queue # TODO: try deque +import time from collections.abc import AsyncIterator from typing import Optional, Tuple, cast, get_args @@ -14,7 +15,7 @@ from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encod from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer from facefusion.streamer import process_vision_frame -from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame +from facefusion.types import PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame # TODO: refine this method @@ -46,7 +47,7 @@ async def handle_video_stream(websocket : WebSocket) -> None: audio_temp = numpy.array([], dtype = numpy.float32) vision_frame_queue.put(first_vision_frame) - rtc_store.create_rtc_peers(session_id) + rtc_store.init_peers(session_id) event_loop = asyncio.get_running_loop() @@ -80,7 +81,7 @@ async def handle_video_stream(websocket : WebSocket) -> None: await video_encode_task await audio_encode_task - rtc_store.destroy_rtc_peers(session_id) + rtc_store.delete_peers(session_id) if websocket.client_state == WebSocketState.CONNECTED: await websocket.close() @@ -111,24 +112,17 @@ async def handle_image_stream(websocket : WebSocket) -> None: # TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop -def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: - rtc_peers = rtc_store.get_rtc_peers(session_id) +def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: + rtc_peers = rtc_store.get_peers(session_id) if rtc_peers is not None: - payload_types = rtc.parse_sdp_payload_types(sdp_offer) + sdp_media = rtc.detect_sdp_media(sdp_offer) peer_connection : PeerConnection = rtc.create_peer_connection() - audio_codec : AudioCodec = 'opus' - audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, payload_types.get(audio_codec, 111)) + rtc.set_remote_description(peer_connection, sdp_offer) - #TODO: Fix me via resolve method - video_codec : VideoCodec = 'av1' - if payload_types.get('av1'): - video_codec = 'av1' - if payload_types.get('vp8'): - video_codec = 'vp8' - - video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', video_codec, payload_types.get(video_codec, 96)) - local_sdp = rtc.negotiate_sdp_answer(peer_connection, sdp_offer) + audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', sdp_media.get('audio').get('codec'), sdp_media.get('audio').get('payload_type')) + video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', sdp_media.get('video').get('codec'), sdp_media.get('video').get('payload_type')) + local_sdp = rtc.create_sdp_answer(peer_connection) if local_sdp: rtc_peer : RtcPeer =\ @@ -159,13 +153,15 @@ def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], if output_resolution == temp_resolution: output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp) - rtc_peers = rtc_store.get_rtc_peers(session_id) + rtc_peers = rtc_store.get_peers(session_id) if output_frame_buffer and rtc_peers: - rtc.send_video_to_peers(rtc_peers, output_frame_buffer) + video_timestamp = int(time.monotonic() * 90000) + rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp) timestamp += 1 vision_frame = vision_frame_queue.get() + #TODO: we are not using continue as control flow in the project continue destroy_aom_encoder(aom_encoder) @@ -192,13 +188,15 @@ def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], if output_resolution == temp_resolution: output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp) - rtc_peers = rtc_store.get_rtc_peers(session_id) + rtc_peers = rtc_store.get_peers(session_id) if output_frame_buffer and rtc_peers: - rtc.send_video_to_peers(rtc_peers, output_frame_buffer) + video_timestamp = int(time.monotonic() * 90000) + rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp) timestamp += 1 vision_frame = vision_frame_queue.get() + # TODO: we are not using continue as control flow in the project continue destroy_vpx_encoder(vpx_encoder) @@ -219,7 +217,7 @@ def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], sessi while audio_chunk: # TODO: improve this condition with b'' audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960) - rtc_peers = rtc_store.get_rtc_peers(session_id) + rtc_peers = rtc_store.get_peers(session_id) if audio_buffer and rtc_peers: rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp) diff --git a/facefusion/libraries/datachannel.py b/facefusion/libraries/datachannel.py index 0d67ec4b..ff4b3064 100644 --- a/facefusion/libraries/datachannel.py +++ b/facefusion/libraries/datachannel.py @@ -179,9 +179,6 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcSetRemoteDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p ] library.rtcSetRemoteDescription.restype = ctypes.c_int - library.rtcAddTrack.argtypes = [ ctypes.c_int, ctypes.c_char_p ] - library.rtcAddTrack.restype = ctypes.c_int - library.rtcAddTrackEx.restype = ctypes.c_int library.rtcSendMessage.argtypes = [ ctypes.c_int, ctypes.c_void_p, ctypes.c_int ] @@ -205,23 +202,8 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcGetLocalDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.c_int ] library.rtcGetLocalDescription.restype = ctypes.c_int - library.rtcSetLocalDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p ] - library.rtcSetLocalDescription.restype = ctypes.c_int - library.rtcSetOpusPacketizer.restype = ctypes.c_int - library.rtcSetUserPointer.argtypes = [ ctypes.c_int, ctypes.c_void_p ] - library.rtcSetUserPointer.restype = None - - library.rtcSetLocalDescriptionCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) ] - library.rtcSetLocalDescriptionCallback.restype = ctypes.c_int - - library.rtcSetGatheringStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ] - library.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int - - library.rtcSetStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ] - library.rtcSetStateChangeCallback.restype = ctypes.c_int - return library diff --git a/facefusion/rtc.py b/facefusion/rtc.py index 47e31edd..fb66e185 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -1,9 +1,8 @@ import ctypes -import time -from typing import Dict, List, Optional +from typing import List, Optional from facefusion.libraries import datachannel as datachannel_module -from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec +from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpMedia, SdpOffer, VideoCodec def create_peer_connection() -> PeerConnection: @@ -12,6 +11,7 @@ def create_peer_connection() -> PeerConnection: rtc_configuration.enableIceUdpMux = True rtc_configuration.forceMediaTransport = True + rtc_configuration.disableAutoNegotiation = True return datachannel_library.rtcCreatePeerConnection(ctypes.byref(rtc_configuration)) @@ -28,9 +28,9 @@ def create_sdp_offer(peer_connection : PeerConnection) -> Optional[SdpOffer]: return None -def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: +def create_sdp_answer(peer_connection : PeerConnection) -> Optional[SdpAnswer]: datachannel_library = datachannel_module.create_static_library() - datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') + datachannel_library.rtcSetLocalDescription(peer_connection, b'answer') sdp_buffer = ctypes.create_string_buffer(8192) @@ -40,40 +40,43 @@ def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer) return None -#TODO: needs revision -def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_pts : int) -> None: +def set_remote_description(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> None: datachannel_library = datachannel_module.create_static_library() - - if rtc_peers: - timestamp = audio_pts & 0xFFFFFFFF - send_buffer = ctypes.create_string_buffer(audio_buffer) - send_total = len(audio_buffer) - - for rtc_peer in rtc_peers: - audio_track_id = rtc_peer.get('audio_track') - - if audio_track_id and datachannel_library.rtcIsOpen(audio_track_id): - datachannel_library.rtcSetTrackRtpTimestamp(audio_track_id, timestamp) - datachannel_library.rtcSendMessage(audio_track_id, send_buffer, send_total) + datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') return None -#TODO: needs revision -def send_video_to_peers(rtc_peers : List[RtcPeer], frame_buffer : bytes) -> None: +def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_timestamp : int) -> None: datachannel_library = datachannel_module.create_static_library() if rtc_peers: - timestamp = int(time.monotonic() * 90000) & 0xFFFFFFFF - send_buffer = ctypes.create_string_buffer(frame_buffer) - send_total = len(frame_buffer) + send_buffer = ctypes.create_string_buffer(audio_buffer) + send_total = len(audio_buffer) for rtc_peer in rtc_peers: - video_track_id = rtc_peer.get('video_track') + audio_track = rtc_peer.get('audio_track') - if video_track_id and datachannel_library.rtcIsOpen(video_track_id): - datachannel_library.rtcSetTrackRtpTimestamp(video_track_id, timestamp) - datachannel_library.rtcSendMessage(video_track_id, send_buffer, send_total) + if datachannel_library.rtcIsOpen(audio_track): + datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp) + datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total) + + return None + + +def send_video_to_peers(rtc_peers : List[RtcPeer], video_buffer : bytes, video_timestamp : int) -> None: + datachannel_library = datachannel_module.create_static_library() + + if rtc_peers: + send_buffer = ctypes.create_string_buffer(video_buffer) + send_total = len(video_buffer) + + for rtc_peer in rtc_peers: + video_track = rtc_peer.get('video_track') + + if datachannel_library.rtcIsOpen(video_track): + datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp) + datachannel_library.rtcSendMessage(video_track, send_buffer, send_total) return None @@ -82,10 +85,10 @@ def delete_peers(rtc_peers : List[RtcPeer]) -> None: datachannel_library = datachannel_module.create_static_library() for rtc_peer in rtc_peers: - peer_connection_id = rtc_peer.get('peer_connection') + peer_connection = rtc_peer.get('peer_connection') - if peer_connection_id: - datachannel_library.rtcDeletePeerConnection(peer_connection_id) + if peer_connection: + datachannel_library.rtcDeletePeerConnection(peer_connection) return None @@ -172,17 +175,28 @@ def create_video_track_init(media_direction : MediaDirection, video_codec : Vide return ctypes.byref(track_init) -#TODO: needs revision -def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]: - payload_types : Dict[str, int] = {} +def detect_sdp_media(sdp_offer : SdpOffer) -> SdpMedia: + sdp_media : SdpMedia = {} - # TODO: consider having a codec helper to resolve these for line in sdp_offer.splitlines(): - if line.startswith('a=rtpmap:') and 'AV1/90000' in line and not payload_types.get('av1'): - payload_types['av1'] = int(line.split(':')[1].split(' ')[0]) - if line.startswith('a=rtpmap:') and 'VP8/90000' in line and not payload_types.get('vp8'): - payload_types['vp8'] = int(line.split(':')[1].split(' ')[0]) - if line.startswith('a=rtpmap:') and 'opus/48000/2' in line and not payload_types.get('opus'): - payload_types['opus'] = int(line.split(':')[1].split(' ')[0]) + if line.startswith('a=rtpmap:'): + if 'av1/90000' in line.lower(): + sdp_media['video'] =\ + { + 'codec': 'av1', + 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) + } + if 'vp8/90000' in line.lower(): + sdp_media['video'] =\ + { + 'codec': 'vp8', + 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) + } + if 'opus/48000/2' in line.lower(): + sdp_media['audio'] =\ + { + 'codec': 'opus', + 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) + } - return payload_types + return sdp_media diff --git a/facefusion/rtc_store.py b/facefusion/rtc_store.py index 41970573..2fb9bcbb 100644 --- a/facefusion/rtc_store.py +++ b/facefusion/rtc_store.py @@ -7,16 +7,24 @@ from facefusion.types import RtcPeer, RtcStore, SessionId RTC_STORE : RtcStore = {} -def get_rtc_peers(session_id : SessionId) -> List[RtcPeer]: - return RTC_STORE.get(session_id) - - -def create_rtc_peers(session_id : SessionId) -> None: +def init_peers(session_id : SessionId) -> None: RTC_STORE[session_id] = [] -def destroy_rtc_peers(session_id : SessionId) -> None: - rtc_peers = RTC_STORE.pop(session_id, None) +def get_peers(session_id : SessionId) -> List[RtcPeer]: + return RTC_STORE.get(session_id) - if rtc_peers: - rtc.delete_peers(rtc_peers) + +def delete_peers(session_id : SessionId) -> None: + if session_id in RTC_STORE: + rtc_peers = get_peers(session_id) + + if rtc_peers: + rtc.delete_peers(rtc_peers) + del RTC_STORE[session_id] + + return None + + +def clear() -> None: + RTC_STORE.clear() diff --git a/facefusion/types.py b/facefusion/types.py index 4d01b68d..f31906c9 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -287,9 +287,26 @@ RtcPeer = TypedDict('RtcPeer', 'video_track': RtcVideoTrack, 'audio_track': RtcAudioTrack, }) - RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]] +SdpAudioMedia = TypedDict('SdpAudioMedia', +{ + 'codec': AudioCodec, + 'payload_type': int +}) + +SdpVideoMedia = TypedDict('SdpVideoMedia', +{ + 'codec': VideoCodec, + 'payload_type': int +}) + +SdpMedia = TypedDict('SdpMedia', +{ + 'video': NotRequired[SdpVideoMedia], + 'audio': NotRequired[SdpAudioMedia] +}) + ModelOptions : TypeAlias = Dict[str, Any] ModelSet : TypeAlias = Dict[str, ModelOptions] ModelInitializer : TypeAlias = NDArray[Any] diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index ae970d27..bb98ba91 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -51,7 +51,7 @@ def create_event() -> threading.Event: @pytest.mark.helper -def set_event(session_id : str, frame_buffer : bytes, event : threading.Event) -> None: +def set_event(session_id : str, media_buffer : bytes, timestamp : int, event : threading.Event) -> None: event.set() diff --git a/tests/test_rtc.py b/tests/test_rtc.py index 63d024b4..544aec94 100644 --- a/tests/test_rtc.py +++ b/tests/test_rtc.py @@ -2,8 +2,9 @@ from typing import List import pytest -from facefusion import rtc, state_manager +from facefusion import state_manager from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module +from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, detect_sdp_media, send_audio_to_peers, send_video_to_peers, set_remote_description from facefusion.types import RtcPeer @@ -17,18 +18,18 @@ def before_all() -> None: def test_create_peer_connection() -> None: - peer_connection = rtc.create_peer_connection() + peer_connection = create_peer_connection() datachannel_library = datachannel_module.create_static_library() - assert peer_connection > 0 + assert peer_connection assert datachannel_library.rtcDeletePeerConnection(peer_connection) == 0 def test_create_sdp_offer() -> None: - peer_connection = rtc.create_peer_connection() - rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) - rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) - sdp_offer = rtc.create_sdp_offer(peer_connection) + sender_peer_connection = create_peer_connection() + add_video_track(sender_peer_connection, 'sendonly', 'vp8', 96) + add_audio_track(sender_peer_connection, 'sendonly', 'opus', 111) + sdp_offer = create_sdp_offer(sender_peer_connection) assert 'm=video' in sdp_offer assert 'VP8/90000' in sdp_offer @@ -37,21 +38,22 @@ def test_create_sdp_offer() -> None: assert 'opus/48000/2' in sdp_offer assert 'a=ssrc:43 cname:audio' in sdp_offer - datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) + datachannel_module.create_static_library().rtcDeletePeerConnection(sender_peer_connection) -def test_negotiate_sdp_answer() -> None: +def test_create_sdp_answer() -> None: datachannel_library = datachannel_module.create_static_library() - sender_connection = rtc.create_peer_connection() - rtc.add_video_track(sender_connection, 'sendonly', 'vp8', 96) - rtc.add_audio_track(sender_connection, 'sendonly', 'opus', 111) - sdp_offer = rtc.create_sdp_offer(sender_connection) + sender_peer_connection = create_peer_connection() + add_video_track(sender_peer_connection, 'sendonly', 'vp8', 96) + add_audio_track(sender_peer_connection, 'sendonly', 'opus', 111) + sdp_offer = create_sdp_offer(sender_peer_connection) - receiver_connection = rtc.create_peer_connection() - rtc.add_video_track(receiver_connection, 'recvonly', 'vp8', 96) - rtc.add_audio_track(receiver_connection, 'recvonly', 'opus', 111) - sdp_answer = rtc.negotiate_sdp_answer(receiver_connection, sdp_offer) + receiver_peer_connection = create_peer_connection() + set_remote_description(receiver_peer_connection, sdp_offer) + add_video_track(receiver_peer_connection, 'recvonly', 'vp8', 96) + add_audio_track(receiver_peer_connection, 'recvonly', 'opus', 111) + sdp_answer = create_sdp_answer(receiver_peer_connection) assert 'm=video' in sdp_answer assert 'VP8/90000' in sdp_answer @@ -61,15 +63,14 @@ def test_negotiate_sdp_answer() -> None: assert 'a=ssrc:43 cname:audio' in sdp_answer assert 'a=recvonly' in sdp_answer - assert datachannel_library.rtcDeletePeerConnection(sender_connection) == 0 - assert datachannel_library.rtcDeletePeerConnection(receiver_connection) == 0 + assert datachannel_library.rtcDeletePeerConnection(sender_peer_connection) == 0 + assert datachannel_library.rtcDeletePeerConnection(receiver_peer_connection) == 0 -# TODO: review def test_send_audio_to_peers() -> None: datachannel_library = datachannel_module.create_static_library() - peer_connection = rtc.create_peer_connection() - audio_track = rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) + peer_connection = create_peer_connection() + audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111) rtc_peers : List[RtcPeer] =\ [ { @@ -79,16 +80,15 @@ def test_send_audio_to_peers() -> None: } ] - rtc.send_audio_to_peers(rtc_peers, bytes(960), 0) + send_audio_to_peers(rtc_peers, bytes(960), 0) datachannel_library.rtcDeletePeerConnection(peer_connection) -# TODO: review def test_send_video_to_peers() -> None: datachannel_library = datachannel_module.create_static_library() - peer_connection = rtc.create_peer_connection() - video_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) + peer_connection = create_peer_connection() + video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96) rtc_peers : List[RtcPeer] =\ [ { @@ -98,14 +98,14 @@ def test_send_video_to_peers() -> None: } ] - rtc.send_video_to_peers(rtc_peers, bytes(1024)) + send_video_to_peers(rtc_peers, bytes(1024), 0) datachannel_library.rtcDeletePeerConnection(peer_connection) def test_delete_peers() -> None: datachannel_library = datachannel_module.create_static_library() - peer_connection = rtc.create_peer_connection() + peer_connection = create_peer_connection() rtc_peers : List[RtcPeer] =\ [ { @@ -115,36 +115,21 @@ def test_delete_peers() -> None: } ] - rtc.delete_peers(rtc_peers) + delete_peers(rtc_peers) - assert datachannel_library.rtcDeletePeerConnection(peer_connection) < 0 + assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1 -def test_add_audio_track() -> None: - peer_connection = rtc.create_peer_connection() - audio_track = rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) +def test_detect_sdp_media() -> None: + peer_connection = create_peer_connection() + add_video_track(peer_connection, 'sendonly', 'vp8', 96) + add_audio_track(peer_connection, 'sendonly', 'opus', 111) + sdp_offer = create_sdp_offer(peer_connection) + sdp_payload = detect_sdp_media(sdp_offer) - assert audio_track > 0 - - # TODO: review - sdp_offer = rtc.create_sdp_offer(peer_connection) - - assert 'opus/48000/2' in sdp_offer + assert sdp_payload.get('video').get('codec') == 'vp8' + assert sdp_payload.get('video').get('payload_type') == 96 + assert sdp_payload.get('audio').get('codec') == 'opus' + assert sdp_payload.get('audio').get('payload_type') == 111 datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) - - -def test_add_video_track() -> None: - peer_connection = rtc.create_peer_connection() - video_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) - - assert video_track > 0 - - # TODO: review - sdp_offer = rtc.create_sdp_offer(peer_connection) - - assert 'VP8/90000' in sdp_offer - - datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) - - diff --git a/tests/test_rtc_store.py b/tests/test_rtc_store.py deleted file mode 100644 index 2e8d6c40..00000000 --- a/tests/test_rtc_store.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import List - -import pytest - -from facefusion import rtc, rtc_store, state_manager -from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module -from facefusion.types import RtcPeer - - -@pytest.fixture(scope = 'module', autouse = True) -def before_all() -> None: - state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) - - datachannel_module.pre_check() - opus_module.pre_check() - vpx_module.pre_check() - - -@pytest.fixture(autouse = True) -def before_each() -> None: - rtc_store.RTC_STORE.clear() - - -# TODO: needs review -def test_create_rtc_peers() -> None: - rtc_store.create_rtc_peers('test-session') - - assert rtc_store.RTC_STORE.get('test-session') == [] - - -# TODO: needs review -def test_get_rtc_peers() -> None: - assert rtc_store.get_rtc_peers('test-session') is None - - rtc_store.create_rtc_peers('test-session') - - assert rtc_store.get_rtc_peers('test-session') == [] - - -# TODO: needs review -def test_destroy_rtc_peers() -> None: - rtc_store.create_rtc_peers('test-session') - rtc_store.destroy_rtc_peers('test-session') - - assert rtc_store.get_rtc_peers('test-session') is None - - -# TODO: needs review -def test_destroy_rtc_peers_with_connections() -> None: - datachannel_library = datachannel_module.create_static_library() - peer_connection = rtc.create_peer_connection() - rtc_store.create_rtc_peers('test-session') - rtc_peers : List[RtcPeer] =\ - [ - { - 'peer_connection': peer_connection, - 'video_track': 0, - 'audio_track': 0 - } - ] - rtc_store.RTC_STORE['test-session'] = rtc_peers - - rtc_store.destroy_rtc_peers('test-session') - - assert rtc_store.get_rtc_peers('test-session') is None - assert datachannel_library.rtcDeletePeerConnection(peer_connection) < 0 diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py index 48832759..8b08ce3b 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_stream_helper.py @@ -209,8 +209,8 @@ def test_handle_video_stream() -> None: websocket.accept.assert_called_once_with(subprotocol = 'proto') websocket.send_text.assert_called_once_with('ready') websocket.close.assert_called_once() - mock_rtc.create_rtc_peers.assert_called_once_with('session-1') - mock_rtc.destroy_rtc_peers.assert_called_once_with('session-1') + mock_rtc.init_peers.assert_called_once_with('session-1') + mock_rtc.delete_peers.assert_called_once_with('session-1') _, loop_session_id, loop_resolution = mock_loop.call_args[0] assert loop_session_id == 'session-1' assert loop_resolution == (64, 64) @@ -224,7 +224,7 @@ def test_handle_video_stream() -> None: asyncio.run(handle_video_stream(websocket)) websocket.accept.assert_called_once() websocket.send_text.assert_not_called() - mock_rtc.create_rtc_peers.assert_not_called() + mock_rtc.init_peers.assert_not_called() websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ]) with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \