From 07c1c936afb5870a84acdb9842e77eda931dfe11 Mon Sep 17 00:00:00 2001 From: Harisreedhar <46858047+harisreedhar@users.noreply.github.com> Date: Fri, 8 May 2026 18:19:41 +0530 Subject: [PATCH] Refine RTC bindings: callback-based SDP negotiation, peer state tracking, and type cleanup (#1088) * Refine RTC bindings: callback-based SDP negotiation, peer state tracking, and type cleanup * fix lint * restore peer_connection and rename methods * remove flags, unused_methods and improve tests * fix indent --- facefusion/datachannel.py | 12 +++++++ facefusion/rtc.py | 75 ++++++++++++++++----------------------- facefusion/rtc_store.py | 30 +++++++++++----- facefusion/types.py | 15 ++++---- tests/test_rtc.py | 63 +++++++++++++++++++------------- 5 files changed, 111 insertions(+), 84 deletions(-) diff --git a/facefusion/datachannel.py b/facefusion/datachannel.py index 13e89182..b77623be 100644 --- a/facefusion/datachannel.py +++ b/facefusion/datachannel.py @@ -138,4 +138,16 @@ def init_ctypes(datachannel_library : ctypes.CDLL) -> ctypes.CDLL: datachannel_library.rtcSetOpusPacketizer.restype = ctypes.c_int + datachannel_library.rtcSetUserPointer.argtypes = [ ctypes.c_int, ctypes.c_void_p ] + datachannel_library.rtcSetUserPointer.restype = None + + datachannel_library.rtcSetLocalDescriptionCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) ] + datachannel_library.rtcSetLocalDescriptionCallback.restype = ctypes.c_int + + datachannel_library.rtcSetGatheringStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ] + datachannel_library.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int + + datachannel_library.rtcSetStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ] + datachannel_library.rtcSetStateChangeCallback.restype = ctypes.c_int + return datachannel_library diff --git a/facefusion/rtc.py b/facefusion/rtc.py index e9278f2d..6a8dbf85 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -1,4 +1,5 @@ import ctypes +import threading import time from typing import List, Optional @@ -112,48 +113,46 @@ def create_sdp(peer_connection : PeerConnection) -> Optional[SdpOffer]: return None +@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) +def on_sdp_ready(peer_connection : int, sdp : Optional[bytes], sdp_type : int, user_pointer : Optional[int]) -> None: + ctypes.cast(user_pointer, ctypes.py_object).value.set() + + +@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) +def on_ice_complete(peer_connection : int, state : int, user_pointer : Optional[int]) -> None: + if state == 2: + context = ctypes.cast(user_pointer, ctypes.py_object).value + context['event'].set() + + def negotiate_sdp(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: datachannel_library = create_static_datachannel_library() - datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') - buffer_size = 16384 - buffer_string = ctypes.create_string_buffer(buffer_size) - wait_limit = time.monotonic() + 5 + sdp_event = threading.Event() + sdp_event_pointer = ctypes.cast(id(sdp_event), ctypes.c_void_p) - while time.monotonic() < wait_limit: - if datachannel_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0: - return buffer_string.value.decode() - time.sleep(0.05) + datachannel_library.rtcSetUserPointer(peer_connection, sdp_event_pointer) + datachannel_library.rtcSetLocalDescriptionCallback(peer_connection, on_sdp_ready) + datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') + sdp_event.wait(timeout = 5) + + sdp_buffer_size = 16384 + sdp_buffer = ctypes.create_string_buffer(sdp_buffer_size) + + if datachannel_library.rtcGetLocalDescription(peer_connection, sdp_buffer, sdp_buffer_size) > 0: + return sdp_buffer.value.decode() return None -def handle_whep_offer(peers : List[RtcPeer], sdp_offer : SdpOffer) -> Optional[SdpAnswer]: - peer_connection = create_peer_connection() - audio_track = add_audio_track(peer_connection, 'sendonly') - video_track = add_video_track(peer_connection, 'sendonly') - local_sdp = negotiate_sdp(peer_connection, sdp_offer) - - if local_sdp: - rtc_peer : RtcPeer =\ - { - 'peer_connection': peer_connection, - 'video_track': video_track, - 'audio_track': audio_track - } - peers.append(rtc_peer) - - return local_sdp - - -def send_to_peers(peers : List[RtcPeer], data : bytes) -> None: +def send_to_peers(rtc_peers : List[RtcPeer], data : bytes) -> None: datachannel_library = create_static_datachannel_library() - if peers: + if rtc_peers: timestamp = int(time.monotonic() * 90000) & 0xFFFFFFFF data_buffer = ctypes.create_string_buffer(data) data_total = len(data) - for rtc_peer in peers: + for rtc_peer in rtc_peers: video_track_id = rtc_peer.get('video_track') if video_track_id and datachannel_library.rtcIsOpen(video_track_id): @@ -163,25 +162,13 @@ def send_to_peers(peers : List[RtcPeer], data : bytes) -> None: return None -def delete_peers(peers : List[RtcPeer]) -> None: +def delete_peers(rtc_peers : List[RtcPeer]) -> None: datachannel_library = create_static_datachannel_library() - for rtc_peer in peers: + for rtc_peer in rtc_peers: peer_connection_id = rtc_peer.get('peer_connection') if peer_connection_id: datachannel_library.rtcDeletePeerConnection(peer_connection_id) - peers.clear() - - -def is_peer_connected(peers : List[RtcPeer]) -> bool: - datachannel_library = create_static_datachannel_library() - - for rtc_peer in peers: - video_track_id = rtc_peer.get('video_track') - - if video_track_id and datachannel_library.rtcIsOpen(video_track_id): - return True - - return False + return None diff --git a/facefusion/rtc_store.py b/facefusion/rtc_store.py index ceb0a1a7..6573d654 100644 --- a/facefusion/rtc_store.py +++ b/facefusion/rtc_store.py @@ -1,7 +1,7 @@ from typing import List, Optional from facefusion import rtc -from facefusion.types import RtcPeer, RtcStreamStore, SdpAnswer, SdpOffer, SessionId +from facefusion.types import PeerConnection, RtcAudioTrack, RtcPeer, RtcStreamStore, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId RTC_STREAMS : RtcStreamStore = {} @@ -15,21 +15,35 @@ def create_rtc_stream(session_id : SessionId) -> None: def destroy_rtc_stream(session_id : SessionId) -> None: - peers = RTC_STREAMS.pop(session_id, None) + rtc_peers = RTC_STREAMS.pop(session_id, None) - if peers: - rtc.delete_peers(peers) + if rtc_peers: + rtc.delete_peers(rtc_peers) def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: if session_id in RTC_STREAMS: - return rtc.handle_whep_offer(RTC_STREAMS.get(session_id), sdp_offer) + peer_connection : PeerConnection = rtc.create_peer_connection() + audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly') + video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly') + local_sdp = rtc.negotiate_sdp(peer_connection, sdp_offer) + + if local_sdp: + rtc_peer : RtcPeer =\ + { + 'peer_connection': peer_connection, + 'video_track': video_track, + 'audio_track': audio_track + } + RTC_STREAMS[session_id].append(rtc_peer) + + return local_sdp return None def send_rtc_frame(session_id : SessionId, frame_data : bytes) -> None: - peers = get_rtc_stream(session_id) + rtc_peers = get_rtc_stream(session_id) - if peers: - rtc.send_to_peers(peers, frame_data) + if rtc_peers: + rtc.send_to_peers(rtc_peers, frame_data) diff --git a/facefusion/types.py b/facefusion/types.py index 7395a7bd..7e5b981e 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -272,19 +272,20 @@ RtcOfferSet = TypedDict('RtcOfferSet', 'type': str }) -RtcPeer = TypedDict('RtcPeer', -{ - 'peer_connection': int, - 'video_track': int, - 'audio_track': int -}) - RtcVideoTrack : TypeAlias = int RtcAudioTrack : TypeAlias = int PeerConnection : TypeAlias = int SdpOffer : TypeAlias = str SdpAnswer : TypeAlias = str MediaDirection : TypeAlias = Literal['sendonly', 'recvonly', 'sendrecv', 'inactive'] + +RtcPeer = TypedDict('RtcPeer', +{ + 'peer_connection': PeerConnection, + 'video_track': RtcVideoTrack, + 'audio_track': RtcAudioTrack, +}) + RtcStreamStore : TypeAlias = Dict[str, List[RtcPeer]] ModelOptions : TypeAlias = Dict[str, Any] diff --git a/tests/test_rtc.py b/tests/test_rtc.py index 167f4d30..54c9bce8 100644 --- a/tests/test_rtc.py +++ b/tests/test_rtc.py @@ -1,6 +1,9 @@ +from typing import List + import pytest from facefusion import rtc +from facefusion.types import RtcPeer @pytest.fixture(scope = 'module') @@ -22,46 +25,56 @@ def test_create_peer_connection() -> None: def test_add_audio_track() -> None: + peer_connection = rtc.create_peer_connection() + + assert rtc.add_audio_track(peer_connection, 'sendonly') > 0 + + rtc.create_static_datachannel_library().rtcDeletePeerConnection(peer_connection) + + +def test_add_video_track() -> None: + peer_connection = rtc.create_peer_connection() + + assert rtc.add_video_track(peer_connection, 'sendonly') > 0 + + rtc.create_static_datachannel_library().rtcDeletePeerConnection(peer_connection) + + +def test_negotiate_sdp() -> None: datachannel_library = rtc.create_static_datachannel_library() sender_connection = rtc.create_peer_connection() - sender_audio_track = rtc.add_audio_track(sender_connection, 'sendonly') + rtc.add_video_track(sender_connection, 'sendonly') + rtc.add_audio_track(sender_connection, 'sendonly') sdp_offer = rtc.create_sdp(sender_connection) receiver_connection = rtc.create_peer_connection() - receiver_audio_track = rtc.add_audio_track(receiver_connection, 'recvonly') + rtc.add_video_track(receiver_connection, 'recvonly') + rtc.add_audio_track(receiver_connection, 'recvonly') sdp_answer = rtc.negotiate_sdp(receiver_connection, sdp_offer) - assert sender_audio_track > 0 - assert receiver_audio_track > 0 - - assert 'm=audio' in sdp_offer + assert sdp_answer + assert 'm=video' in sdp_answer + assert 'VP8/90000' in sdp_answer assert 'm=audio' in sdp_answer - assert 'opus/48000/2' in sdp_offer assert 'opus/48000/2' in sdp_answer assert datachannel_library.rtcDeletePeerConnection(sender_connection) == 0 assert datachannel_library.rtcDeletePeerConnection(receiver_connection) == 0 -def test_add_video_track() -> None: +def test_delete_peers() -> None: datachannel_library = rtc.create_static_datachannel_library() + peer_connection = rtc.create_peer_connection() + rtc_peers : List[RtcPeer] =\ + [ + { + 'peer_connection': peer_connection, + 'video_track': 0, + 'audio_track': 0 + } + ] - sender_connection = rtc.create_peer_connection() - sender_video_track = rtc.add_video_track(sender_connection, 'sendonly') - sdp_offer = rtc.create_sdp(sender_connection) + rtc.delete_peers(rtc_peers) - receiver_connection = rtc.create_peer_connection() - receiver_video_track = rtc.add_video_track(receiver_connection, 'recvonly') - sdp_answer = rtc.negotiate_sdp(receiver_connection, sdp_offer) - - assert sender_video_track > 0 - assert receiver_video_track > 0 - - assert 'm=video' in sdp_offer - assert 'm=video' in sdp_answer - assert 'VP8/90000' in sdp_offer - assert 'VP8/90000' in sdp_answer - - assert datachannel_library.rtcDeletePeerConnection(sender_connection) == 0 - assert datachannel_library.rtcDeletePeerConnection(receiver_connection) == 0 + assert datachannel_library.rtcDeletePeerConnection(peer_connection) < 0