diff --git a/facefusion/rtc.py b/facefusion/rtc.py index c8b66e11..856af9c2 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -1,5 +1,4 @@ import ctypes -import threading import time from typing import Dict, List, Optional @@ -7,80 +6,41 @@ from facefusion.libraries import datachannel as datachannel_module from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec -# TODO: reduce to only used params -def create_peer_connection( - ice_servers : Optional[ctypes.Array[ctypes.c_char_p]] = None, - ice_servers_count : int = 0, proxy_server : Optional[bytes] = None, - bind_address : Optional[bytes] = None, certificate_type : int = 0, - ice_transport_policy : int = 0, - enable_ice_tcp : bool = False, - enable_ice_udp_mux : bool = True, - disable_auto_negotiation : bool = False, - force_media_transport : bool = True, - port_range_begin : int = 0, - port_range_end : int = 0, - max_packet_size : int = 0, - max_message_size : int = 0) -> PeerConnection: - +def create_peer_connection() -> PeerConnection: datachannel_library = datachannel_module.create_static_library() rtc_configuration = datachannel_module.define_rtc_configuration() - rtc_configuration.iceServers = ice_servers - rtc_configuration.iceServersCount = ice_servers_count - rtc_configuration.proxyServer = proxy_server - rtc_configuration.bindAddress = bind_address - rtc_configuration.certificateType = certificate_type - rtc_configuration.iceTransportPolicy = ice_transport_policy - rtc_configuration.enableIceTcp = enable_ice_tcp - rtc_configuration.enableIceUdpMux = enable_ice_udp_mux - rtc_configuration.disableAutoNegotiation = disable_auto_negotiation - rtc_configuration.forceMediaTransport = force_media_transport - rtc_configuration.portRangeBegin = port_range_begin - rtc_configuration.portRangeEnd = port_range_end - rtc_configuration.mtu = max_packet_size - rtc_configuration.maxMessageSize = max_message_size + rtc_configuration.enableIceUdpMux = True + rtc_configuration.forceMediaTransport = True return datachannel_library.rtcCreatePeerConnection(ctypes.byref(rtc_configuration)) -# TODO: check if sleep is needed def create_sdp_offer(peer_connection : PeerConnection) -> Optional[SdpOffer]: datachannel_library = datachannel_module.create_static_library() datachannel_library.rtcSetLocalDescription(peer_connection, b'offer') - buffer_size = 16384 - buffer_string = ctypes.create_string_buffer(buffer_size) - wait_limit = time.monotonic() + 5 + sdp_buffer = ctypes.create_string_buffer(8192) - while time.monotonic() < wait_limit: - if datachannel_library.rtcGetLocalDescription(peer_connection, buffer_string, buffer_size) > 0: - return buffer_string.value.decode() - - time.sleep(0.05) - - return None - - -# TODO: sanitize sdp_offer, wrap in run_in_executor, track peer connection state -def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: - datachannel_library = datachannel_module.create_static_library() - sdp_event = threading.Event() - sdp_event_pointer = ctypes.cast(id(sdp_event), ctypes.c_void_p) - - datachannel_library.rtcSetUserPointer(peer_connection, sdp_event_pointer) - datachannel_library.rtcSetLocalDescriptionCallback(peer_connection, on_sdp_ready) - datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') - sdp_event.wait(timeout = 5) - - sdp_buffer_size = 8192 - sdp_buffer = ctypes.create_string_buffer(sdp_buffer_size) - - if datachannel_library.rtcGetLocalDescription(peer_connection, sdp_buffer, sdp_buffer_size) > 0: + if datachannel_library.rtcGetLocalDescription(peer_connection, sdp_buffer, 8192): return sdp_buffer.value.decode() return None +def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: + datachannel_library = datachannel_module.create_static_library() + datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer') + + sdp_buffer = ctypes.create_string_buffer(8192) + + if datachannel_library.rtcGetLocalDescription(peer_connection, sdp_buffer, 8192): + return sdp_buffer.value.decode() + + return None + + +#TODO: needs revision def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_pts : int) -> None: datachannel_library = datachannel_module.create_static_library() @@ -99,6 +59,7 @@ def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_p return None +#TODO: needs revision def send_video_to_peers(rtc_peers : List[RtcPeer], frame_buffer : bytes) -> None: datachannel_library = datachannel_module.create_static_library() @@ -131,9 +92,8 @@ def delete_peers(rtc_peers : List[RtcPeer]) -> None: def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> RtcAudioTrack: datachannel_library = datachannel_module.create_static_library() - media_description = create_audio_description(media_direction, audio_codec, payload_type) - - audio_track = datachannel_library.rtcAddTrack(peer_connection, media_description) + audio_description = create_audio_description(media_direction, audio_codec, payload_type) + audio_track = datachannel_library.rtcAddTrack(peer_connection, audio_description) audio_packetizer = datachannel_module.define_rtc_packetizer_init() audio_packetizer.ssrc = 43 @@ -151,9 +111,8 @@ def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDir def add_video_track(peer_connection : PeerConnection, media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> RtcVideoTrack: datachannel_library = datachannel_module.create_static_library() - media_description = create_video_description(media_direction, video_codec, payload_type) - - video_track = datachannel_library.rtcAddTrack(peer_connection, media_description) + video_description = create_video_description(media_direction, video_codec, payload_type) + video_track = datachannel_library.rtcAddTrack(peer_connection, video_description) video_packetizer = datachannel_module.define_rtc_packetizer_init() video_packetizer.ssrc = 42 @@ -165,6 +124,7 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir if video_codec == 'av1': video_packetizer.obuPacketization = 1 datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer)) + if video_codec == 'vp8': datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) @@ -174,6 +134,7 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir return video_track +#TODO: needs revision def create_audio_description(media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> bytes: rtp_codec = 'opus/48000/2' if audio_codec == 'opus': @@ -193,6 +154,7 @@ def create_audio_description(media_direction : MediaDirection, audio_codec : Aud return '\r\n'.join(lines).encode() +#TODO: needs revision def create_video_description(media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> bytes: rtp_codec = 'AV1/90000' if video_codec == 'av1': @@ -214,6 +176,7 @@ def create_video_description(media_direction : MediaDirection, video_codec : Vid return '\r\n'.join(lines).encode() +#TODO: needs revision def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]: payload_types : Dict[str, int] = {} @@ -227,8 +190,3 @@ def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]: payload_types['opus'] = int(line.split(':')[1].split(' ')[0]) return payload_types - - -@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) -def on_sdp_ready(peer_connection : int, sdp : Optional[bytes], sdp_type : int, user_pointer : Optional[int]) -> None: - ctypes.cast(user_pointer, ctypes.py_object).value.set() diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index a629b6f3..ae970d27 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -132,7 +132,7 @@ def test_stream_video(test_client : TestClient, create_event : threading.Event, websocket.send_bytes(chr(1).encode() + source_content) websocket.receive_text() - peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True) + peer_connection = rtc.create_peer_connection() rtc.add_video_track(peer_connection, 'recvonly', 'vp8', 96) rtc.add_audio_track(peer_connection, 'recvonly', 'opus', 111) sdp_offer = rtc.create_sdp_offer(peer_connection) diff --git a/tests/test_rtc.py b/tests/test_rtc.py index 88ebbd2c..c09e7f38 100644 --- a/tests/test_rtc.py +++ b/tests/test_rtc.py @@ -25,7 +25,7 @@ def test_create_peer_connection() -> None: def test_create_sdp_offer() -> None: - peer_connection = rtc.create_peer_connection(disable_auto_negotiation = True) + peer_connection = rtc.create_peer_connection() rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) sdp_offer = rtc.create_sdp_offer(peer_connection) @@ -64,6 +64,44 @@ def test_negotiate_sdp_answer() -> None: assert datachannel_library.rtcDeletePeerConnection(receiver_connection) == 0 +# TODO: review +def test_send_audio_to_peers() -> None: + datachannel_library = datachannel_module.create_static_library() + peer_connection = rtc.create_peer_connection() + audio_track = rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) + rtc_peers : List[RtcPeer] =\ + [ + { + 'peer_connection': peer_connection, + 'video_track': 0, + 'audio_track': audio_track + } + ] + + rtc.send_audio_to_peers(rtc_peers, bytes(960), 0) + + datachannel_library.rtcDeletePeerConnection(peer_connection) + + +# TODO: review +def test_send_video_to_peers() -> None: + datachannel_library = datachannel_module.create_static_library() + peer_connection = rtc.create_peer_connection() + video_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) + rtc_peers : List[RtcPeer] =\ + [ + { + 'peer_connection': peer_connection, + 'video_track': video_track, + 'audio_track': 0 + } + ] + + rtc.send_video_to_peers(rtc_peers, bytes(1024)) + + datachannel_library.rtcDeletePeerConnection(peer_connection) + + def test_delete_peers() -> None: datachannel_library = datachannel_module.create_static_library() peer_connection = rtc.create_peer_connection()