mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 10:51:39 +02:00
Refactor/rtc cleanup 3 (#1118)
* tweak rtc store and make the decision to ban trivial testing * clear todos for rtc_test, remove redundant tests * clear todos for rtc_test, remove redundant tests * break negotiation out of rtc flow, introduce create_sdp_answer and set_remote_description * add todo * move timeline control to the stream helper, clean send_audio|video_to_peers * rename some methods * fix test * introduce detect_sdp_media * introduce detect_sdp_media
This commit is contained in:
@@ -5,7 +5,7 @@ from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion import session_context, session_manager
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import add_rtc_viewer, handle_image_stream, handle_video_stream
|
||||
from facefusion.apis.stream_helper import connect_rtc, handle_image_stream, handle_video_stream
|
||||
|
||||
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
@@ -28,7 +28,7 @@ async def post_stream(request : Request) -> Response:
|
||||
|
||||
if content_type == 'application/sdp' and session_id:
|
||||
sdp_offer = await request.body()
|
||||
sdp_answer = add_rtc_viewer(session_id, sdp_offer.decode())
|
||||
sdp_answer = connect_rtc(session_id, sdp_offer.decode())
|
||||
|
||||
if sdp_answer:
|
||||
return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import queue # TODO: try deque
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Optional, Tuple, cast, get_args
|
||||
|
||||
@@ -14,7 +15,7 @@ from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encod
|
||||
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
|
||||
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
from facefusion.types import PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
|
||||
|
||||
# TODO: refine this method
|
||||
@@ -46,7 +47,7 @@ async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
audio_temp = numpy.array([], dtype = numpy.float32)
|
||||
|
||||
vision_frame_queue.put(first_vision_frame)
|
||||
rtc_store.create_rtc_peers(session_id)
|
||||
rtc_store.init_peers(session_id)
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
|
||||
@@ -80,7 +81,7 @@ async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
await video_encode_task
|
||||
await audio_encode_task
|
||||
|
||||
rtc_store.destroy_rtc_peers(session_id)
|
||||
rtc_store.delete_peers(session_id)
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
@@ -111,24 +112,17 @@ async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
|
||||
|
||||
# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop
|
||||
def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if rtc_peers is not None:
|
||||
payload_types = rtc.parse_sdp_payload_types(sdp_offer)
|
||||
sdp_media = rtc.detect_sdp_media(sdp_offer)
|
||||
peer_connection : PeerConnection = rtc.create_peer_connection()
|
||||
audio_codec : AudioCodec = 'opus'
|
||||
audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, payload_types.get(audio_codec, 111))
|
||||
rtc.set_remote_description(peer_connection, sdp_offer)
|
||||
|
||||
#TODO: Fix me via resolve method
|
||||
video_codec : VideoCodec = 'av1'
|
||||
if payload_types.get('av1'):
|
||||
video_codec = 'av1'
|
||||
if payload_types.get('vp8'):
|
||||
video_codec = 'vp8'
|
||||
|
||||
video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', video_codec, payload_types.get(video_codec, 96))
|
||||
local_sdp = rtc.negotiate_sdp_answer(peer_connection, sdp_offer)
|
||||
audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', sdp_media.get('audio').get('codec'), sdp_media.get('audio').get('payload_type'))
|
||||
video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', sdp_media.get('video').get('codec'), sdp_media.get('video').get('payload_type'))
|
||||
local_sdp = rtc.create_sdp_answer(peer_connection)
|
||||
|
||||
if local_sdp:
|
||||
rtc_peer : RtcPeer =\
|
||||
@@ -159,13 +153,15 @@ def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]],
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if output_frame_buffer and rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
|
||||
video_timestamp = int(time.monotonic() * 90000)
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp)
|
||||
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
#TODO: we are not using continue as control flow in the project
|
||||
continue
|
||||
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
@@ -192,13 +188,15 @@ def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]],
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if output_frame_buffer and rtc_peers:
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer)
|
||||
video_timestamp = int(time.monotonic() * 90000)
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp)
|
||||
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
# TODO: we are not using continue as control flow in the project
|
||||
continue
|
||||
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
@@ -219,7 +217,7 @@ def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], sessi
|
||||
|
||||
while audio_chunk: # TODO: improve this condition with b''
|
||||
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960)
|
||||
rtc_peers = rtc_store.get_rtc_peers(session_id)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if audio_buffer and rtc_peers:
|
||||
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
|
||||
|
||||
@@ -179,9 +179,6 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.rtcSetRemoteDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p ]
|
||||
library.rtcSetRemoteDescription.restype = ctypes.c_int
|
||||
|
||||
library.rtcAddTrack.argtypes = [ ctypes.c_int, ctypes.c_char_p ]
|
||||
library.rtcAddTrack.restype = ctypes.c_int
|
||||
|
||||
library.rtcAddTrackEx.restype = ctypes.c_int
|
||||
|
||||
library.rtcSendMessage.argtypes = [ ctypes.c_int, ctypes.c_void_p, ctypes.c_int ]
|
||||
@@ -205,23 +202,8 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.rtcGetLocalDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.c_int ]
|
||||
library.rtcGetLocalDescription.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetLocalDescription.argtypes = [ ctypes.c_int, ctypes.c_char_p ]
|
||||
library.rtcSetLocalDescription.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetOpusPacketizer.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetUserPointer.argtypes = [ ctypes.c_int, ctypes.c_void_p ]
|
||||
library.rtcSetUserPointer.restype = None
|
||||
|
||||
library.rtcSetLocalDescriptionCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p) ]
|
||||
library.rtcSetLocalDescriptionCallback.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetGatheringStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ]
|
||||
library.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetStateChangeCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p) ]
|
||||
library.rtcSetStateChangeCallback.restype = ctypes.c_int
|
||||
|
||||
return library
|
||||
|
||||
|
||||
|
||||
+56
-42
@@ -1,9 +1,8 @@
|
||||
import ctypes
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from facefusion.libraries import datachannel as datachannel_module
|
||||
from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec
|
||||
from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpMedia, SdpOffer, VideoCodec
|
||||
|
||||
|
||||
def create_peer_connection() -> PeerConnection:
|
||||
@@ -12,6 +11,7 @@ def create_peer_connection() -> PeerConnection:
|
||||
|
||||
rtc_configuration.enableIceUdpMux = True
|
||||
rtc_configuration.forceMediaTransport = True
|
||||
rtc_configuration.disableAutoNegotiation = True
|
||||
|
||||
return datachannel_library.rtcCreatePeerConnection(ctypes.byref(rtc_configuration))
|
||||
|
||||
@@ -28,9 +28,9 @@ def create_sdp_offer(peer_connection : PeerConnection) -> Optional[SdpOffer]:
|
||||
return None
|
||||
|
||||
|
||||
def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
|
||||
def create_sdp_answer(peer_connection : PeerConnection) -> Optional[SdpAnswer]:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer')
|
||||
datachannel_library.rtcSetLocalDescription(peer_connection, b'answer')
|
||||
|
||||
sdp_buffer = ctypes.create_string_buffer(8192)
|
||||
|
||||
@@ -40,40 +40,43 @@ def negotiate_sdp_answer(peer_connection : PeerConnection, sdp_offer : SdpOffer)
|
||||
return None
|
||||
|
||||
|
||||
#TODO: needs revision
|
||||
def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_pts : int) -> None:
|
||||
def set_remote_description(peer_connection : PeerConnection, sdp_offer : SdpOffer) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
if rtc_peers:
|
||||
timestamp = audio_pts & 0xFFFFFFFF
|
||||
send_buffer = ctypes.create_string_buffer(audio_buffer)
|
||||
send_total = len(audio_buffer)
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
audio_track_id = rtc_peer.get('audio_track')
|
||||
|
||||
if audio_track_id and datachannel_library.rtcIsOpen(audio_track_id):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(audio_track_id, timestamp)
|
||||
datachannel_library.rtcSendMessage(audio_track_id, send_buffer, send_total)
|
||||
datachannel_library.rtcSetRemoteDescription(peer_connection, sdp_offer.encode(), b'offer')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
#TODO: needs revision
|
||||
def send_video_to_peers(rtc_peers : List[RtcPeer], frame_buffer : bytes) -> None:
|
||||
def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_timestamp : int) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
if rtc_peers:
|
||||
timestamp = int(time.monotonic() * 90000) & 0xFFFFFFFF
|
||||
send_buffer = ctypes.create_string_buffer(frame_buffer)
|
||||
send_total = len(frame_buffer)
|
||||
send_buffer = ctypes.create_string_buffer(audio_buffer)
|
||||
send_total = len(audio_buffer)
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
video_track_id = rtc_peer.get('video_track')
|
||||
audio_track = rtc_peer.get('audio_track')
|
||||
|
||||
if video_track_id and datachannel_library.rtcIsOpen(video_track_id):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(video_track_id, timestamp)
|
||||
datachannel_library.rtcSendMessage(video_track_id, send_buffer, send_total)
|
||||
if datachannel_library.rtcIsOpen(audio_track):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp)
|
||||
datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def send_video_to_peers(rtc_peers : List[RtcPeer], video_buffer : bytes, video_timestamp : int) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
if rtc_peers:
|
||||
send_buffer = ctypes.create_string_buffer(video_buffer)
|
||||
send_total = len(video_buffer)
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
video_track = rtc_peer.get('video_track')
|
||||
|
||||
if datachannel_library.rtcIsOpen(video_track):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp)
|
||||
datachannel_library.rtcSendMessage(video_track, send_buffer, send_total)
|
||||
|
||||
return None
|
||||
|
||||
@@ -82,10 +85,10 @@ def delete_peers(rtc_peers : List[RtcPeer]) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
peer_connection_id = rtc_peer.get('peer_connection')
|
||||
peer_connection = rtc_peer.get('peer_connection')
|
||||
|
||||
if peer_connection_id:
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection_id)
|
||||
if peer_connection:
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
return None
|
||||
|
||||
@@ -172,17 +175,28 @@ def create_video_track_init(media_direction : MediaDirection, video_codec : Vide
|
||||
return ctypes.byref(track_init)
|
||||
|
||||
|
||||
#TODO: needs revision
|
||||
def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]:
|
||||
payload_types : Dict[str, int] = {}
|
||||
def detect_sdp_media(sdp_offer : SdpOffer) -> SdpMedia:
|
||||
sdp_media : SdpMedia = {}
|
||||
|
||||
# TODO: consider having a codec helper to resolve these
|
||||
for line in sdp_offer.splitlines():
|
||||
if line.startswith('a=rtpmap:') and 'AV1/90000' in line and not payload_types.get('av1'):
|
||||
payload_types['av1'] = int(line.split(':')[1].split(' ')[0])
|
||||
if line.startswith('a=rtpmap:') and 'VP8/90000' in line and not payload_types.get('vp8'):
|
||||
payload_types['vp8'] = int(line.split(':')[1].split(' ')[0])
|
||||
if line.startswith('a=rtpmap:') and 'opus/48000/2' in line and not payload_types.get('opus'):
|
||||
payload_types['opus'] = int(line.split(':')[1].split(' ')[0])
|
||||
if line.startswith('a=rtpmap:'):
|
||||
if 'av1/90000' in line.lower():
|
||||
sdp_media['video'] =\
|
||||
{
|
||||
'codec': 'av1',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
if 'vp8/90000' in line.lower():
|
||||
sdp_media['video'] =\
|
||||
{
|
||||
'codec': 'vp8',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
if 'opus/48000/2' in line.lower():
|
||||
sdp_media['audio'] =\
|
||||
{
|
||||
'codec': 'opus',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
|
||||
return payload_types
|
||||
return sdp_media
|
||||
|
||||
+17
-9
@@ -7,16 +7,24 @@ from facefusion.types import RtcPeer, RtcStore, SessionId
|
||||
RTC_STORE : RtcStore = {}
|
||||
|
||||
|
||||
def get_rtc_peers(session_id : SessionId) -> List[RtcPeer]:
|
||||
return RTC_STORE.get(session_id)
|
||||
|
||||
|
||||
def create_rtc_peers(session_id : SessionId) -> None:
|
||||
def init_peers(session_id : SessionId) -> None:
|
||||
RTC_STORE[session_id] = []
|
||||
|
||||
|
||||
def destroy_rtc_peers(session_id : SessionId) -> None:
|
||||
rtc_peers = RTC_STORE.pop(session_id, None)
|
||||
def get_peers(session_id : SessionId) -> List[RtcPeer]:
|
||||
return RTC_STORE.get(session_id)
|
||||
|
||||
if rtc_peers:
|
||||
rtc.delete_peers(rtc_peers)
|
||||
|
||||
def delete_peers(session_id : SessionId) -> None:
|
||||
if session_id in RTC_STORE:
|
||||
rtc_peers = get_peers(session_id)
|
||||
|
||||
if rtc_peers:
|
||||
rtc.delete_peers(rtc_peers)
|
||||
del RTC_STORE[session_id]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
RTC_STORE.clear()
|
||||
|
||||
+18
-1
@@ -287,9 +287,26 @@ RtcPeer = TypedDict('RtcPeer',
|
||||
'video_track': RtcVideoTrack,
|
||||
'audio_track': RtcAudioTrack,
|
||||
})
|
||||
|
||||
RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]]
|
||||
|
||||
SdpAudioMedia = TypedDict('SdpAudioMedia',
|
||||
{
|
||||
'codec': AudioCodec,
|
||||
'payload_type': int
|
||||
})
|
||||
|
||||
SdpVideoMedia = TypedDict('SdpVideoMedia',
|
||||
{
|
||||
'codec': VideoCodec,
|
||||
'payload_type': int
|
||||
})
|
||||
|
||||
SdpMedia = TypedDict('SdpMedia',
|
||||
{
|
||||
'video': NotRequired[SdpVideoMedia],
|
||||
'audio': NotRequired[SdpAudioMedia]
|
||||
})
|
||||
|
||||
ModelOptions : TypeAlias = Dict[str, Any]
|
||||
ModelSet : TypeAlias = Dict[str, ModelOptions]
|
||||
ModelInitializer : TypeAlias = NDArray[Any]
|
||||
|
||||
@@ -51,7 +51,7 @@ def create_event() -> threading.Event:
|
||||
|
||||
|
||||
@pytest.mark.helper
|
||||
def set_event(session_id : str, frame_buffer : bytes, event : threading.Event) -> None:
|
||||
def set_event(session_id : str, media_buffer : bytes, timestamp : int, event : threading.Event) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
|
||||
+40
-55
@@ -2,8 +2,9 @@ from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from facefusion import rtc, state_manager
|
||||
from facefusion import state_manager
|
||||
from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
|
||||
from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, detect_sdp_media, send_audio_to_peers, send_video_to_peers, set_remote_description
|
||||
from facefusion.types import RtcPeer
|
||||
|
||||
|
||||
@@ -17,18 +18,18 @@ def before_all() -> None:
|
||||
|
||||
|
||||
def test_create_peer_connection() -> None:
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
peer_connection = create_peer_connection()
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
assert peer_connection > 0
|
||||
assert peer_connection
|
||||
assert datachannel_library.rtcDeletePeerConnection(peer_connection) == 0
|
||||
|
||||
|
||||
def test_create_sdp_offer() -> None:
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = rtc.create_sdp_offer(peer_connection)
|
||||
sender_peer_connection = create_peer_connection()
|
||||
add_video_track(sender_peer_connection, 'sendonly', 'vp8', 96)
|
||||
add_audio_track(sender_peer_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = create_sdp_offer(sender_peer_connection)
|
||||
|
||||
assert 'm=video' in sdp_offer
|
||||
assert 'VP8/90000' in sdp_offer
|
||||
@@ -37,21 +38,22 @@ def test_create_sdp_offer() -> None:
|
||||
assert 'opus/48000/2' in sdp_offer
|
||||
assert 'a=ssrc:43 cname:audio' in sdp_offer
|
||||
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(sender_peer_connection)
|
||||
|
||||
|
||||
def test_negotiate_sdp_answer() -> None:
|
||||
def test_create_sdp_answer() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
sender_connection = rtc.create_peer_connection()
|
||||
rtc.add_video_track(sender_connection, 'sendonly', 'vp8', 96)
|
||||
rtc.add_audio_track(sender_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = rtc.create_sdp_offer(sender_connection)
|
||||
sender_peer_connection = create_peer_connection()
|
||||
add_video_track(sender_peer_connection, 'sendonly', 'vp8', 96)
|
||||
add_audio_track(sender_peer_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = create_sdp_offer(sender_peer_connection)
|
||||
|
||||
receiver_connection = rtc.create_peer_connection()
|
||||
rtc.add_video_track(receiver_connection, 'recvonly', 'vp8', 96)
|
||||
rtc.add_audio_track(receiver_connection, 'recvonly', 'opus', 111)
|
||||
sdp_answer = rtc.negotiate_sdp_answer(receiver_connection, sdp_offer)
|
||||
receiver_peer_connection = create_peer_connection()
|
||||
set_remote_description(receiver_peer_connection, sdp_offer)
|
||||
add_video_track(receiver_peer_connection, 'recvonly', 'vp8', 96)
|
||||
add_audio_track(receiver_peer_connection, 'recvonly', 'opus', 111)
|
||||
sdp_answer = create_sdp_answer(receiver_peer_connection)
|
||||
|
||||
assert 'm=video' in sdp_answer
|
||||
assert 'VP8/90000' in sdp_answer
|
||||
@@ -61,15 +63,14 @@ def test_negotiate_sdp_answer() -> None:
|
||||
assert 'a=ssrc:43 cname:audio' in sdp_answer
|
||||
assert 'a=recvonly' in sdp_answer
|
||||
|
||||
assert datachannel_library.rtcDeletePeerConnection(sender_connection) == 0
|
||||
assert datachannel_library.rtcDeletePeerConnection(receiver_connection) == 0
|
||||
assert datachannel_library.rtcDeletePeerConnection(sender_peer_connection) == 0
|
||||
assert datachannel_library.rtcDeletePeerConnection(receiver_peer_connection) == 0
|
||||
|
||||
|
||||
# TODO: review
|
||||
def test_send_audio_to_peers() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
audio_track = rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
peer_connection = create_peer_connection()
|
||||
audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
{
|
||||
@@ -79,16 +80,15 @@ def test_send_audio_to_peers() -> None:
|
||||
}
|
||||
]
|
||||
|
||||
rtc.send_audio_to_peers(rtc_peers, bytes(960), 0)
|
||||
send_audio_to_peers(rtc_peers, bytes(960), 0)
|
||||
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
|
||||
# TODO: review
|
||||
def test_send_video_to_peers() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
video_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
peer_connection = create_peer_connection()
|
||||
video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
{
|
||||
@@ -98,14 +98,14 @@ def test_send_video_to_peers() -> None:
|
||||
}
|
||||
]
|
||||
|
||||
rtc.send_video_to_peers(rtc_peers, bytes(1024))
|
||||
send_video_to_peers(rtc_peers, bytes(1024), 0)
|
||||
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
|
||||
def test_delete_peers() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
peer_connection = create_peer_connection()
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
{
|
||||
@@ -115,36 +115,21 @@ def test_delete_peers() -> None:
|
||||
}
|
||||
]
|
||||
|
||||
rtc.delete_peers(rtc_peers)
|
||||
delete_peers(rtc_peers)
|
||||
|
||||
assert datachannel_library.rtcDeletePeerConnection(peer_connection) < 0
|
||||
assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1
|
||||
|
||||
|
||||
def test_add_audio_track() -> None:
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
audio_track = rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
def test_detect_sdp_media() -> None:
|
||||
peer_connection = create_peer_connection()
|
||||
add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = create_sdp_offer(peer_connection)
|
||||
sdp_payload = detect_sdp_media(sdp_offer)
|
||||
|
||||
assert audio_track > 0
|
||||
|
||||
# TODO: review
|
||||
sdp_offer = rtc.create_sdp_offer(peer_connection)
|
||||
|
||||
assert 'opus/48000/2' in sdp_offer
|
||||
assert sdp_payload.get('video').get('codec') == 'vp8'
|
||||
assert sdp_payload.get('video').get('payload_type') == 96
|
||||
assert sdp_payload.get('audio').get('codec') == 'opus'
|
||||
assert sdp_payload.get('audio').get('payload_type') == 111
|
||||
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
|
||||
def test_add_video_track() -> None:
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
video_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
|
||||
assert video_track > 0
|
||||
|
||||
# TODO: review
|
||||
sdp_offer = rtc.create_sdp_offer(peer_connection)
|
||||
|
||||
assert 'VP8/90000' in sdp_offer
|
||||
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from facefusion import rtc, rtc_store, state_manager
|
||||
from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
|
||||
from facefusion.types import RtcPeer
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
|
||||
def before_all() -> None:
|
||||
state_manager.init_item('download_providers', [ 'github', 'huggingface' ])
|
||||
|
||||
datachannel_module.pre_check()
|
||||
opus_module.pre_check()
|
||||
vpx_module.pre_check()
|
||||
|
||||
|
||||
@pytest.fixture(autouse = True)
|
||||
def before_each() -> None:
|
||||
rtc_store.RTC_STORE.clear()
|
||||
|
||||
|
||||
# TODO: needs review
|
||||
def test_create_rtc_peers() -> None:
|
||||
rtc_store.create_rtc_peers('test-session')
|
||||
|
||||
assert rtc_store.RTC_STORE.get('test-session') == []
|
||||
|
||||
|
||||
# TODO: needs review
|
||||
def test_get_rtc_peers() -> None:
|
||||
assert rtc_store.get_rtc_peers('test-session') is None
|
||||
|
||||
rtc_store.create_rtc_peers('test-session')
|
||||
|
||||
assert rtc_store.get_rtc_peers('test-session') == []
|
||||
|
||||
|
||||
# TODO: needs review
|
||||
def test_destroy_rtc_peers() -> None:
|
||||
rtc_store.create_rtc_peers('test-session')
|
||||
rtc_store.destroy_rtc_peers('test-session')
|
||||
|
||||
assert rtc_store.get_rtc_peers('test-session') is None
|
||||
|
||||
|
||||
# TODO: needs review
|
||||
def test_destroy_rtc_peers_with_connections() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
rtc_store.create_rtc_peers('test-session')
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video_track': 0,
|
||||
'audio_track': 0
|
||||
}
|
||||
]
|
||||
rtc_store.RTC_STORE['test-session'] = rtc_peers
|
||||
|
||||
rtc_store.destroy_rtc_peers('test-session')
|
||||
|
||||
assert rtc_store.get_rtc_peers('test-session') is None
|
||||
assert datachannel_library.rtcDeletePeerConnection(peer_connection) < 0
|
||||
@@ -209,8 +209,8 @@ def test_handle_video_stream() -> None:
|
||||
websocket.accept.assert_called_once_with(subprotocol = 'proto')
|
||||
websocket.send_text.assert_called_once_with('ready')
|
||||
websocket.close.assert_called_once()
|
||||
mock_rtc.create_rtc_peers.assert_called_once_with('session-1')
|
||||
mock_rtc.destroy_rtc_peers.assert_called_once_with('session-1')
|
||||
mock_rtc.init_peers.assert_called_once_with('session-1')
|
||||
mock_rtc.delete_peers.assert_called_once_with('session-1')
|
||||
_, loop_session_id, loop_resolution = mock_loop.call_args[0]
|
||||
assert loop_session_id == 'session-1'
|
||||
assert loop_resolution == (64, 64)
|
||||
@@ -224,7 +224,7 @@ def test_handle_video_stream() -> None:
|
||||
asyncio.run(handle_video_stream(websocket))
|
||||
websocket.accept.assert_called_once()
|
||||
websocket.send_text.assert_not_called()
|
||||
mock_rtc.create_rtc_peers.assert_not_called()
|
||||
mock_rtc.init_peers.assert_not_called()
|
||||
|
||||
websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ])
|
||||
with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
|
||||
|
||||
Reference in New Issue
Block a user