mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 10:51:39 +02:00
Migrate to WHIP (#1120)
* migrate to whip part1 * migrate to whip part2 * migrate to whip part3 * migrate to whip part4 * migrate to whip/whep with bidirectional * migrate to whip/whep with bidirectional * use next library * add _next to lid datachannel files * cleanup and add todos * use internal helper rtcGetPayloadTypesForCodec * fix lint * refactor decode() * move logic to codecs * move logic to codecs * break encoders and decoders into multiple files * break encoders and decoders into multiple files * cleanup more * drop action for stream endpoints, keep type for self documentation * restore the v4 store * fix: align frame_width and frame_height to even in both collect() and read_resolution() in both decoders. --------- Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
This commit is contained in:
@@ -5,32 +5,32 @@ from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion import session_context, session_manager
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import connect_rtc, handle_image_stream, handle_video_stream
|
||||
from facefusion.apis.stream_helper import process_image, process_video
|
||||
|
||||
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
stream_mode = websocket.query_params.get('mode')
|
||||
stream_type = websocket.query_params.get('type')
|
||||
|
||||
if stream_mode == 'image':
|
||||
return await handle_image_stream(websocket)
|
||||
|
||||
if stream_mode == 'video':
|
||||
return await handle_video_stream(websocket)
|
||||
if stream_type == 'image':
|
||||
return await process_image(websocket)
|
||||
|
||||
return await websocket.close(1008)
|
||||
|
||||
|
||||
async def post_stream(request : Request) -> Response:
|
||||
stream_type = request.query_params.get('type')
|
||||
content_type = request.headers.get('content-type')
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
session_context.set_session_id(session_id)
|
||||
|
||||
if content_type == 'application/sdp' and session_id:
|
||||
sdp_offer = await request.body()
|
||||
sdp_answer = connect_rtc(session_id, sdp_offer.decode())
|
||||
|
||||
if sdp_answer:
|
||||
if stream_type == 'video':
|
||||
sdp_answer = process_video(session_id, sdp_offer.decode())
|
||||
|
||||
return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp')
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
+307
-180
@@ -1,92 +1,24 @@
|
||||
import asyncio
|
||||
import queue # TODO: try deque
|
||||
import ctypes
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, cast, get_args
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
|
||||
from facefusion import rtc, rtc_store, session_context, session_manager, state_manager
|
||||
from facefusion import rtc, rtc_store, session_context, session_manager, state_manager, streamer
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer
|
||||
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
|
||||
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
from facefusion.audio import create_empty_audio_frame
|
||||
from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder
|
||||
from facefusion.libraries import datachannel as datachannel_module
|
||||
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, OpusDecoder, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
|
||||
|
||||
|
||||
# TODO: refine this method
|
||||
async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
video_codec : VideoCodec = 'av1'
|
||||
audio_codec : AudioCodec = 'opus'
|
||||
|
||||
if websocket.query_params.get('codec') in get_args(VideoCodec):
|
||||
video_codec = cast(VideoCodec, websocket.query_params.get('codec'))
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if session_id:
|
||||
stream_frames = receive_stream_frames(websocket)
|
||||
first_vision_frame : Optional[VisionFrame] = None
|
||||
|
||||
async for first_frame_type, first_frame_buffer in stream_frames:
|
||||
if first_frame_type == 1:
|
||||
first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
break
|
||||
|
||||
if numpy.any(first_vision_frame):
|
||||
resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0])
|
||||
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
|
||||
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
|
||||
audio_temp = numpy.array([], dtype = numpy.float32)
|
||||
|
||||
vision_frame_queue.put(first_vision_frame)
|
||||
rtc_store.init_peers(session_id)
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
|
||||
video_encode_task = event_loop.run_in_executor(None, encode_video_loop, video_codec, vision_frame_queue, session_id, resolution)
|
||||
audio_encode_task = event_loop.run_in_executor(None, encode_audio_loop, audio_codec, audio_chunk_queue, session_id)
|
||||
await websocket.send_text('ready')
|
||||
|
||||
async for frame_type, frame_buffer in stream_frames:
|
||||
if frame_type == 1:
|
||||
vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(vision_frame):
|
||||
if vision_frame_queue.qsize():
|
||||
vision_frame_queue.get_nowait()
|
||||
vision_frame_queue.put(vision_frame)
|
||||
|
||||
if frame_type == 2:
|
||||
audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ])
|
||||
|
||||
while len(audio_temp) >= 1920:
|
||||
audio_chunk_queue.put(audio_temp[:1920].tobytes())
|
||||
audio_temp = audio_temp[1920:]
|
||||
|
||||
vision_frame_queue.put(None)
|
||||
audio_chunk_queue.put(None)
|
||||
|
||||
await video_encode_task
|
||||
await audio_encode_task
|
||||
|
||||
rtc_store.delete_peers(session_id)
|
||||
|
||||
if websocket.client_state == WebSocketState.CONNECTED:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream
|
||||
async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
#TODO: needs review
|
||||
async def process_image(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
@@ -99,7 +31,7 @@ async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
|
||||
|
||||
if numpy.any(capture_vision_frame):
|
||||
output_vision_frame = process_vision_frame(capture_vision_frame)
|
||||
output_vision_frame = streamer.process_frame(create_empty_audio_frame(), capture_vision_frame)
|
||||
is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame)
|
||||
|
||||
if is_success:
|
||||
@@ -109,108 +41,7 @@ async def handle_image_stream(websocket : WebSocket) -> None:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop
|
||||
def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if rtc_peers is not None:
|
||||
sdp_media = rtc.detect_sdp_media(sdp_offer)
|
||||
peer_connection : PeerConnection = rtc.create_peer_connection()
|
||||
rtc.set_remote_description(peer_connection, sdp_offer)
|
||||
|
||||
audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', sdp_media.get('audio').get('codec'), sdp_media.get('audio').get('payload_type'))
|
||||
video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', sdp_media.get('video').get('codec'), sdp_media.get('video').get('payload_type'))
|
||||
local_sdp = rtc.create_sdp_answer(peer_connection)
|
||||
|
||||
if local_sdp:
|
||||
rtc_peer : RtcPeer =\
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video_track': video_track,
|
||||
'audio_track': audio_track
|
||||
}
|
||||
rtc_peers.append(rtc_peer)
|
||||
|
||||
return local_sdp
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def encode_video_loop(video_codec : VideoCodec, vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
create_encoder = partial(create_aom_encoder, 4500, 8, 10)
|
||||
destroy_encoder = destroy_aom_encoder
|
||||
encode_buffer = encode_aom_buffer
|
||||
|
||||
if video_codec == 'vp8':
|
||||
create_encoder = partial(create_vpx_encoder, 4500, 8, 16)
|
||||
destroy_encoder = destroy_vpx_encoder # type:ignore[assignment]
|
||||
encode_buffer = encode_vpx_buffer # type:ignore[assignment]
|
||||
|
||||
encoder = create_encoder(frame_resolution)
|
||||
temp_resolution = frame_resolution
|
||||
timestamp = 0
|
||||
|
||||
vision_frame = vision_frame_queue.get()
|
||||
|
||||
while numpy.any(vision_frame) and encoder:
|
||||
output_vision_frame = process_vision_frame(vision_frame)
|
||||
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
|
||||
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_buffer(encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if output_frame_buffer and rtc_peers:
|
||||
video_timestamp = int(time.monotonic() * 90000)
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp)
|
||||
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
else:
|
||||
destroy_encoder(encoder)
|
||||
temp_resolution = output_resolution
|
||||
encoder = create_encoder(temp_resolution)
|
||||
timestamp = 0
|
||||
|
||||
if encoder:
|
||||
destroy_encoder(encoder)
|
||||
|
||||
|
||||
def encode_audio_loop(audio_codec : AudioCodec, audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
|
||||
opus_encoder = create_opus_encoder(48000, 2)
|
||||
audio_timestamp = 0
|
||||
|
||||
audio_chunk = audio_chunk_queue.get()
|
||||
|
||||
while audio_chunk: # TODO: improve this condition with b''
|
||||
audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if audio_buffer and rtc_peers:
|
||||
rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp)
|
||||
|
||||
audio_timestamp += 960
|
||||
audio_chunk = audio_chunk_queue.get()
|
||||
|
||||
if opus_encoder:
|
||||
destroy_opus_encoder(opus_encoder)
|
||||
|
||||
|
||||
# TODO: needs refinement
|
||||
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
while websocket_event.get('type') == 'websocket.receive':
|
||||
frame_buffer = websocket_event.get('bytes') or bytes()
|
||||
|
||||
if len(frame_buffer) > 1:
|
||||
yield frame_buffer[0], frame_buffer[1:]
|
||||
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
|
||||
# TODO: needs refinement, does it receive frames or a buffer?
|
||||
#TODO: needs review
|
||||
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
@@ -222,3 +53,299 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr
|
||||
yield vision_frame
|
||||
|
||||
websocket_event = await websocket.receive()
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
    """
    Negotiate a bidirectional peer connection for the given SDP offer.

    AV1 is chosen when the offer yields a payload type for it, VP8 otherwise.
    A receive and a send track are added for video; the same pair is added for
    Opus audio only when the offer yields an Opus payload type. When an answer
    is produced, the peer is stored for the session and run_peer_loop() is
    submitted to the default executor of the event loop.

    :param session_id: session the peer is registered under
    :param sdp_offer: decoded SDP offer of the remote side
    :return: the local SDP answer, or None when no payload type is found for
        the selected video codec
    """
    video_codec : VideoCodec = 'av1' if rtc.get_payload_type(sdp_offer, 'av1') else 'vp8'
    video_payload_type = rtc.get_payload_type(sdp_offer, video_codec)

    if not video_payload_type:
        return None

    peer_connection : PeerConnection = rtc.create_peer_connection()
    video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', video_codec, video_payload_type)
    video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', video_codec, video_payload_type)

    audio_codec : AudioCodec = 'opus'
    audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec)
    audio_receiver_track = None
    audio_sender_track = None

    if audio_payload_type:
        audio_receiver_track = rtc.add_audio_track(peer_connection, 'recvonly', audio_codec, audio_payload_type)
        audio_sender_track = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, audio_payload_type)

    # the tracks are registered before the remote description is applied
    rtc.set_remote_description(peer_connection, sdp_offer)
    local_sdp = rtc.create_sdp_answer(peer_connection)

    if not local_sdp:
        return local_sdp

    video_info =\
    {
        'sender_track': video_sender_track,
        'receiver_track': video_receiver_track,
        'codec': video_codec
    }
    rtc_peer : RtcPeer =\
    {
        'peer_connection': peer_connection,
        'video': video_info
    }

    if audio_receiver_track and audio_sender_track:
        rtc_peer['audio'] =\
        {
            'sender_track': audio_sender_track,
            'receiver_track': audio_receiver_track,
            'codec': audio_codec
        }

    rtc_store.init_peers(session_id)
    rtc_store.get_peers(session_id).append(rtc_peer)

    # fire and forget: the returned future is intentionally not awaited here
    asyncio.get_event_loop().run_in_executor(None, run_peer_loop, session_id, rtc_peer)
    return local_sdp
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
    """
    Receive, process and send back the media of one peer until the video stalls.

    The first video message is awaited for up to 30 seconds and used to read
    the incoming resolution. Afterwards each iteration optionally pulls one
    audio message, runs streamer.process_frame(), encodes the result and sends
    it to the peer, then advances to the most recent decodable video frame. The
    loop ends when no frame can be decoded within 30 seconds.

    The decode resolution (incoming frames) and the encode resolution
    (processed frames) are tracked separately, so a processor that changes the
    frame size does not break decoding of the following frames. All native
    codecs and the stored peers are released in a finally block, also when an
    exception escapes the loop.

    :param session_id: session whose peers are deleted on exit
    :param rtc_peer: peer holding the 'video' and optional 'audio' track info
    """
    datachannel_library = datachannel_module.create_static_library()
    video_info = rtc_peer.get('video')
    video_codec = video_info.get('codec')
    video_track = video_info.get('receiver_track')
    audio_info = rtc_peer.get('audio')
    video_receive_buffer = ctypes.create_string_buffer(512 * 1024)
    audio_receive_buffer = ctypes.create_string_buffer(8 * 1024)
    video_decoder = None
    audio_decoder = None
    video_encoder = None
    audio_encoder = None

    try:
        video_decoder = create_video_decoder(video_codec)
        audio_decoder = opus_decoder.create(48000, 2) if audio_info else None

        # without a decoder nothing can be read, avoid handing None to the native library
        if video_decoder is None:
            return

        frame_buffer = poll_for_buffer(datachannel_library, video_track, video_receive_buffer, 30.0)

        if frame_buffer is None:
            return

        decode_resolution = read_video_resolution(video_codec, video_decoder, frame_buffer)

        if decode_resolution is None:
            return

        vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer, decode_resolution)

        if vision_frame is None:
            return

        audio_frame = create_empty_audio_frame()
        encode_resolution : Resolution = decode_resolution
        video_encoder = create_video_encoder(video_codec, encode_resolution)
        audio_encoder = opus_encoder.create(48000, 2)
        frame_index = 0

        while True:
            if audio_info and audio_decoder:
                audio_frame = receive_audio_frame(datachannel_library, audio_info.get('receiver_track'), audio_decoder, audio_receive_buffer)

            output_vision_frame = streamer.process_frame(audio_frame, vision_frame)
            output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])

            # only the encoder follows the processed frame size, decoding keeps the incoming size
            if output_resolution != encode_resolution:
                encode_resolution = output_resolution
                destroy_video_encoder(video_codec, video_encoder)
                video_encoder = create_video_encoder(video_codec, encode_resolution)

            raw_vision_frame = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420)
            encoded_video_buffer = bytes()

            if video_codec == 'av1':
                encoded_video_buffer = aom_encoder.encode(video_encoder, raw_vision_frame.tobytes(), encode_resolution, frame_index)
            if video_codec == 'vp8':
                encoded_video_buffer = vpx_encoder.encode(video_encoder, raw_vision_frame.tobytes(), encode_resolution, frame_index)

            if encoded_video_buffer:
                # timestamp derived from the monotonic clock scaled by 90000
                video_timestamp = int(time.monotonic() * 90000)
                rtc.send_video(rtc_peer, encoded_video_buffer, video_timestamp)

            if audio_encoder and audio_frame is not None and audio_frame.size > 0:
                encoded_audio_buffer = opus_encoder.encode(audio_encoder, audio_frame.tobytes(), 960)

                if encoded_audio_buffer:
                    # timestamp derived from the monotonic clock scaled by 48000
                    audio_timestamp = int(time.monotonic() * 48000)
                    rtc.send_audio(rtc_peer, encoded_audio_buffer, audio_timestamp)

            frame_index += 1

            # prefer the newest frame that is already queued, otherwise wait for the next one
            next_frame = drain_to_latest_frame(datachannel_library, video_track, video_codec, video_decoder, video_receive_buffer, decode_resolution)

            if next_frame is None:
                next_frame = poll_for_frame(datachannel_library, video_track, video_codec, video_decoder, video_receive_buffer, decode_resolution, 30.0)

            if next_frame is None:
                break

            vision_frame = next_frame
    finally:
        if video_encoder is not None:
            destroy_video_encoder(video_codec, video_encoder)
        if audio_encoder:
            opus_encoder.destroy(audio_encoder)
        cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder)
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def cleanup_peer(session_id : SessionId, rtc_peer : RtcPeer, video_codec : VideoCodec, video_decoder : Optional[VpxDecoder | AomDecoder], audio_decoder : Optional[OpusDecoder]) -> None:
    """
    Destroy the decoders of a peer and delete the peers stored for the session.

    :param session_id: session passed to rtc_store.delete_peers()
    :param rtc_peer: accepted for interface symmetry, not read here
    :param video_codec: selects which decoder module destroys video_decoder
    :param video_decoder: video decoder to destroy, skipped when falsy
    :param audio_decoder: Opus decoder to destroy, skipped when falsy
    """
    if video_decoder and video_codec == 'av1':
        aom_decoder.destroy(video_decoder)
    elif video_decoder and video_codec == 'vp8':
        vpx_decoder.destroy(video_decoder)

    if audio_decoder:
        opus_decoder.destroy(audio_decoder)

    rtc_store.delete_peers(session_id)
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomDecoder]:
    """
    Create the decoder that matches the video codec.

    :param video_codec: 'av1' or 'vp8'
    :return: the created decoder, or None for any other codec or when the
        codec module fails to create one
    """
    if video_codec == 'vp8':
        return vpx_decoder.create()

    return aom_decoder.create() if video_codec == 'av1' else None
|
||||
|
||||
|
||||
#TODO: needs review - remove as both are the same
|
||||
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution) -> Optional[VpxEncoder | AomEncoder]:
    """
    Create the encoder that matches the video codec for the given resolution.

    Both codecs receive the same fixed arguments after the resolution
    (8000, 8, 10).

    :param video_codec: 'av1' or 'vp8'
    :param resolution: frame resolution handed to the encoder
    :return: the created encoder, or None for any other codec
    """
    encoder_arguments = (resolution, 8000, 8, 10)

    if video_codec == 'vp8':
        return vpx_encoder.create(*encoder_arguments)
    if video_codec == 'av1':
        return aom_encoder.create(*encoder_arguments)

    return None
|
||||
|
||||
|
||||
#TODO: needs review - remove as this is a trivial helper
|
||||
def destroy_video_encoder(video_codec : VideoCodec, video_encoder : Optional[VpxEncoder | AomEncoder]) -> None:
    """
    Destroy a video encoder through the module that matches the codec.

    :param video_codec: 'av1' or 'vp8', any other value is a no-op
    :param video_encoder: encoder handed to the destroy function unchanged
    """
    if video_codec == 'vp8':
        vpx_encoder.destroy(video_encoder)
    elif video_codec == 'av1':
        aom_encoder.destroy(video_encoder)
|
||||
|
||||
|
||||
def read_video_resolution(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[Resolution]:
    """
    Read the frame resolution of an encoded buffer with the matching decoder.

    :param video_codec: 'av1' or 'vp8'
    :param video_decoder: decoder created for the same codec
    :param frame_buffer: encoded frame to inspect
    :return: the resolution reported by the codec module, or None for any
        other codec
    """
    if video_codec == 'vp8':
        return vpx_decoder.read_resolution(video_decoder, frame_buffer)

    return aom_decoder.read_resolution(video_decoder, frame_buffer) if video_codec == 'av1' else None
|
||||
|
||||
|
||||
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes, frame_resolution : Resolution) -> Optional[VisionFrame]:
    """
    Decode an encoded buffer into a BGR vision frame.

    The decoder output is interpreted as planar I420 of the given resolution.
    A buffer whose size does not match that layout (for instance after the
    sender switched resolution) is rejected instead of raising from reshape().

    :param video_codec: 'av1' or 'vp8'
    :param video_decoder: decoder created for the same codec
    :param frame_buffer: encoded frame
    :param frame_resolution: expected (width, height) of the decoded frame
    :return: the BGR frame, or None when nothing was decoded, the codec is
        unknown or the decoded size does not match frame_resolution
    """
    output_buffer = bytes()

    if video_codec == 'av1':
        output_buffer = aom_decoder.decode(video_decoder, frame_buffer)
    if video_codec == 'vp8':
        output_buffer = vpx_decoder.decode(video_decoder, frame_buffer)

    if output_buffer:
        frame_width, frame_height = frame_resolution
        # I420 stacks the luma rows and the half-size chroma planes into height * 3 / 2 rows
        yuv_height = frame_height * 3 // 2

        if len(output_buffer) == yuv_height * frame_width:
            yuv_frame = numpy.frombuffer(output_buffer, dtype = numpy.uint8).reshape((yuv_height, frame_width))
            return cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)

    return None
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def receive_audio_frame(datachannel_library : ctypes.CDLL, audio_track : int, audio_decoder : OpusDecoder, receive_buffer : ctypes.Array[ctypes.c_char]) -> AudioFrame:
    """
    Receive one message from the audio track and decode it as Opus.

    The capacity reported to rtcReceiveMessage() is taken from receive_buffer
    itself, so the native call can never be told the buffer is larger than it
    really is.

    :param datachannel_library: loaded datachannel library
    :param audio_track: track id to receive from
    :param audio_decoder: Opus decoder used for the payload
    :param receive_buffer: reusable buffer the message is written into
    :return: the decoded samples as float32 array, or an empty audio frame when
        nothing was received or decoded
    """
    buffer_size = ctypes.c_int(ctypes.sizeof(receive_buffer))
    receive_output = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size))

    if receive_output == 0 and buffer_size.value > 0:
        opus_buffer = receive_buffer.raw[:buffer_size.value]
        # 960 samples per channel and 2 channels, matching the decoder created by the caller
        output_buffer = opus_decoder.decode(audio_decoder, opus_buffer, 960, 2)

        if output_buffer:
            return numpy.frombuffer(output_buffer, dtype = numpy.float32)

    return create_empty_audio_frame()
|
||||
|
||||
|
||||
def receive_video_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[bytes]:
    """
    Receive one message from the video track without waiting.

    The capacity reported to rtcReceiveMessage() is taken from receive_buffer
    itself instead of a hard-coded size, so a differently sized buffer cannot
    be overrun.

    :param datachannel_library: loaded datachannel library
    :param video_track: track id to receive from
    :param receive_buffer: reusable buffer the message is written into
    :return: a copy of the received bytes, or None when the call did not
        return 0 or reported no payload
    """
    buffer_size = ctypes.c_int(ctypes.sizeof(receive_buffer))
    receive_output = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size))

    if receive_output == 0 and buffer_size.value > 0:
        return receive_buffer.raw[:buffer_size.value]

    return None
|
||||
|
||||
|
||||
def poll_for_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char], timeout : float) -> Optional[bytes]:
    """
    Poll the video track until a message arrives or the timeout elapses.

    :param datachannel_library: loaded datachannel library
    :param video_track: track id to receive from
    :param receive_buffer: reusable buffer the message is written into
    :param timeout: maximum time to wait in seconds
    :return: the received bytes, or None when the deadline passes first
    """
    deadline = time.monotonic() + timeout
    frame_buffer = None

    while frame_buffer is None and time.monotonic() < deadline:
        frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer)

        if frame_buffer is None:
            # short pause between attempts
            time.sleep(0.001)

    return frame_buffer
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def poll_for_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution, timeout : float) -> Optional[VisionFrame]:
    """
    Poll the video track until a frame can be decoded or the timeout elapses.

    :param datachannel_library: loaded datachannel library
    :param video_track: track id to receive from
    :param video_codec: codec of the incoming stream
    :param video_decoder: decoder created for the same codec
    :param receive_buffer: reusable buffer the message is written into
    :param frame_resolution: resolution used to decode the frame
    :param timeout: maximum time to wait in seconds
    :return: the decoded frame, or None when the deadline passes first
    """
    deadline = time.monotonic() + timeout
    vision_frame = None

    while vision_frame is None and time.monotonic() < deadline:
        vision_frame = try_receive_frame(datachannel_library, video_track, video_codec, video_decoder, receive_buffer, frame_resolution)

        if vision_frame is None:
            # short pause between attempts
            time.sleep(0.001)

    return vision_frame
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def try_receive_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution) -> Optional[VisionFrame]:
    """
    Make a single attempt to receive and decode a video frame.

    :return: the decoded frame, or None when no message is available or the
        message cannot be decoded
    """
    frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer)

    if not frame_buffer:
        return None

    return decode_video_frame(video_codec, video_decoder, frame_buffer, frame_resolution)
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def drain_to_latest_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution) -> Optional[VisionFrame]:
    """
    Consume every message that is currently available and keep the last frame.

    Each pending message is passed through the decoder, not only the final
    one. A decoded frame counts as present when it is not None, so a frame
    that consists of zeros only (a black picture) is kept as well.

    :param datachannel_library: loaded datachannel library
    :param video_track: track id to receive from
    :param video_codec: codec of the incoming stream
    :param video_decoder: decoder created for the same codec
    :param receive_buffer: reusable buffer the messages are written into
    :param frame_resolution: resolution used to decode the frames
    :return: the most recent decoded frame, or None when nothing was decoded
    """
    last_vision_frame : Optional[VisionFrame] = None

    while (frame_buffer := receive_video_buffer(datachannel_library, video_track, receive_buffer)) is not None:
        vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer, frame_resolution)

        if vision_frame is not None:
            last_vision_frame = vision_frame

    return last_vision_frame
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.libraries import aom as aom_module
|
||||
from facefusion.types import AomDecoder, Resolution
|
||||
|
||||
|
||||
def create() -> Optional[AomDecoder]:
    """
    Create and initialise an AV1 decoder.

    The decoder is a 128 byte buffer that the library initialises in place.
    NOTE(review): 128 presumably covers the size of the codec context and 22
    presumably is the decoder ABI version of the bundled library - confirm.

    :return: the initialised decoder, or None when the library is not
        available or initialisation does not return 0
    """
    aom_library = aom_module.create_static_library()

    if not aom_library:
        return None

    aom_decoder = ctypes.create_string_buffer(128)
    aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_dx_algo')
    init_status = aom_library.aom_codec_dec_init_ver(aom_decoder, ctypes.byref(aom_codec), None, 0, 22)

    return aom_decoder if init_status == 0 else None
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def decode(aom_decoder : AomDecoder, input_buffer : bytes) -> bytes:
    """
    Decode an AV1 buffer and return the first frame as packed plane bytes.

    :param aom_decoder: decoder returned by create()
    :param input_buffer: encoded data
    :return: the bytes gathered by collect(), or empty bytes when the library
        is missing, the input is empty, decoding does not return 0 or no frame
        is available
    """
    aom_library = aom_module.create_static_library()

    if not aom_library or not input_buffer:
        return bytes()

    input_total = len(input_buffer)
    # the library receives a mutable copy of the input
    temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer)

    if aom_library.aom_codec_decode(aom_decoder, temp_buffer, input_total, None) != 0:
        return bytes()

    frame_pointer = aom_library.aom_codec_get_frame(aom_decoder, ctypes.byref(ctypes.c_void_p(0)))

    return collect(frame_pointer) if frame_pointer else bytes()
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def collect(frame_pointer : int) -> bytes:
    """
    Copy the three planes of a decoded image into one contiguous buffer.

    Width and height are read as unsigned ints at offsets 28 and 32 and rounded
    down to even values. Three plane pointers are read from offset 64 and three
    int strides from offset 88. The first plane is copied at full size, the
    other two at half width and half height, row by row so that stride padding
    is dropped. NOTE(review): the offsets presumably mirror the image struct of
    the bundled library on a 64 bit build - confirm.

    :param frame_pointer: address of the image returned by the decoder
    :return: the first plane followed by the second and third plane
    """
    frame_width = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1
    frame_height = ctypes.c_uint.from_address(frame_pointer + 32).value & ~1
    planes_offset = frame_pointer + 64
    strides_offset = frame_pointer + 88
    # rows are gathered in a list and joined once, repeated bytes += copies the whole buffer per row
    plane_rows = []

    for index in range(3):
        plane_pointer = ctypes.c_void_p.from_address(planes_offset + index * 8).value
        stride = ctypes.c_int.from_address(strides_offset + index * 4).value
        plane_shift = 1 if index > 0 else 0
        plane_width = frame_width >> plane_shift
        plane_height = frame_height >> plane_shift

        for row in range(plane_height):
            plane_rows.append(ctypes.string_at(plane_pointer + row * stride, plane_width))

    return b''.join(plane_rows)
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def read_resolution(aom_decoder : AomDecoder, input_buffer : bytes) -> Optional[Resolution]:
    """
    Decode an AV1 buffer and report the resolution of the resulting frame.

    The buffer is passed through the decoder to obtain a frame. Width and
    height are read as unsigned ints at offsets 28 and 32 of that frame and
    rounded down to even values.

    :param aom_decoder: decoder returned by create()
    :param input_buffer: encoded data
    :return: (width, height), or None when the library is missing, the input
        is empty, decoding does not return 0 or no frame is available
    """
    aom_library = aom_module.create_static_library()

    if not aom_library or not input_buffer:
        return None

    input_total = len(input_buffer)
    temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer)

    if aom_library.aom_codec_decode(aom_decoder, temp_buffer, input_total, None) != 0:
        return None

    frame_pointer = aom_library.aom_codec_get_frame(aom_decoder, ctypes.byref(ctypes.c_void_p(0)))

    if not frame_pointer:
        return None

    # clearing the lowest bit keeps both dimensions even
    frame_width = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1
    frame_height = ctypes.c_uint.from_address(frame_pointer + 32).value & ~1
    return frame_width, frame_height
|
||||
|
||||
|
||||
def destroy(aom_decoder : AomDecoder) -> None:
    """
    Destroy an AV1 decoder, doing nothing when the library is not available.

    :param aom_decoder: decoder returned by create()
    """
    aom_library = aom_module.create_static_library()

    if not aom_library:
        return

    aom_library.aom_codec_destroy(aom_decoder)
|
||||
@@ -6,7 +6,7 @@ from facefusion.libraries import aom as aom_module
|
||||
from facefusion.types import AomEncoder, BitRate, Resolution
|
||||
|
||||
|
||||
def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[AomEncoder]:
|
||||
def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[AomEncoder]:
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
if aom_library:
|
||||
@@ -33,7 +33,7 @@ def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, fram
|
||||
return None
|
||||
|
||||
|
||||
def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes:
|
||||
def encode(aom_encoder : AomEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes:
|
||||
aom_library = aom_module.create_static_library()
|
||||
output_buffer = bytes()
|
||||
|
||||
@@ -42,12 +42,12 @@ def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_reso
|
||||
encode_buffer = ctypes.create_string_buffer(input_buffer)
|
||||
|
||||
if aom_library.aom_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and aom_library.aom_codec_encode(aom_encoder, temp_buffer, frame_index, 1, 0, 1) == 0:
|
||||
output_buffer = collect_aom_buffer(aom_encoder)
|
||||
output_buffer = collect(aom_encoder)
|
||||
|
||||
return output_buffer
|
||||
|
||||
|
||||
def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes:
|
||||
def collect(aom_encoder : AomEncoder) -> bytes:
|
||||
aom_library = aom_module.create_static_library()
|
||||
output_buffer = bytes()
|
||||
|
||||
@@ -65,7 +65,7 @@ def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes:
|
||||
return output_buffer
|
||||
|
||||
|
||||
def destroy_aom_encoder(aom_encoder : AomEncoder) -> None:
|
||||
def destroy(aom_encoder : AomEncoder) -> None:
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
if aom_library:
|
||||
@@ -0,0 +1,36 @@
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.libraries import opus as opus_module
|
||||
from facefusion.types import OpusDecoder
|
||||
|
||||
|
||||
def create(sample_rate : int, channel_total : int) -> Optional[OpusDecoder]:
    """
    Create an Opus decoder.

    :param sample_rate: sample rate handed to opus_decoder_create()
    :param channel_total: channel count handed to opus_decoder_create()
    :return: the value returned by opus_decoder_create(), or None when the
        library is not available
    """
    opus_library = opus_module.create_static_library()

    if not opus_library:
        return None

    # the error output is required by the call, its value is not inspected
    error_code = ctypes.c_int(0)
    return opus_library.opus_decoder_create(sample_rate, channel_total, ctypes.byref(error_code))
|
||||
|
||||
|
||||
def decode(opus_decoder : OpusDecoder, input_buffer : bytes, frame_size : int, channel_total : int) -> bytes:
    """
    Decode an Opus packet into interleaved float samples.

    opus_decode_float() reports failures as negative values. Only a positive
    sample count is turned into output, so an invalid packet yields empty bytes
    instead of a negative size being handed to ctypes.string_at().

    :param opus_decoder: decoder returned by create()
    :param input_buffer: encoded Opus packet
    :param frame_size: capacity of the output in samples per channel
    :param channel_total: number of interleaved channels
    :return: the decoded samples as raw float bytes, or empty bytes when the
        library is missing or nothing was decoded
    """
    opus_library = opus_module.create_static_library()
    output_buffer = bytes()

    if opus_library:
        input_total = len(input_buffer)
        decode_buffer = (ctypes.c_float * (frame_size * channel_total))()
        decode_length = opus_library.opus_decode_float(opus_decoder, input_buffer, input_total, decode_buffer, frame_size, 0)

        if decode_length > 0:
            # decode_length counts samples per channel
            output_buffer = ctypes.string_at(ctypes.addressof(decode_buffer), decode_length * channel_total * ctypes.sizeof(ctypes.c_float))

    return output_buffer
|
||||
|
||||
|
||||
def destroy(opus_decoder : OpusDecoder) -> None:
    """
    Release an Opus decoder via opus_decoder_destroy.

    Does nothing when the Opus library could not be loaded.
    """
    opus_library = opus_module.create_static_library()

    if not opus_library:
        return None

    opus_library.opus_decoder_destroy(opus_decoder)
    return None
|
||||
@@ -5,7 +5,7 @@ from facefusion.libraries import opus as opus_module
|
||||
from facefusion.types import OpusEncoder
|
||||
|
||||
|
||||
def create_opus_encoder(sample_rate : int, channel_total : int) -> Optional[OpusEncoder]:
|
||||
def create(sample_rate : int, channel_total : int) -> Optional[OpusEncoder]:
|
||||
opus_library = opus_module.create_static_library()
|
||||
|
||||
if opus_library:
|
||||
@@ -14,7 +14,7 @@ def create_opus_encoder(sample_rate : int, channel_total : int) -> Optional[Opus
|
||||
return None
|
||||
|
||||
|
||||
def encode_opus_buffer(opus_encoder : OpusEncoder, input_buffer : bytes, frame_size : int) -> bytes:
|
||||
def encode(opus_encoder : OpusEncoder, input_buffer : bytes, frame_size : int) -> bytes:
|
||||
opus_library = opus_module.create_static_library()
|
||||
output_buffer = bytes()
|
||||
|
||||
@@ -29,7 +29,7 @@ def encode_opus_buffer(opus_encoder : OpusEncoder, input_buffer : bytes, frame_s
|
||||
return output_buffer
|
||||
|
||||
|
||||
def destroy_opus_encoder(opus_encoder : OpusEncoder) -> None:
|
||||
def destroy(opus_encoder : OpusEncoder) -> None:
|
||||
opus_library = opus_module.create_static_library()
|
||||
|
||||
if opus_library:
|
||||
@@ -0,0 +1,82 @@
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.libraries import vpx as vpx_module
|
||||
from facefusion.types import Resolution, VpxDecoder
|
||||
|
||||
|
||||
def create() -> Optional[VpxDecoder]:
    """
    Create and initialise a VP8 decoder context.

    The context is a 64 byte ctypes buffer that vpx_codec_dec_init_ver
    fills in. Returns the buffer when initialisation reports 0, otherwise
    None (also when the vpx library could not be loaded).
    """
    vpx_library = vpx_module.create_static_library()

    if not vpx_library:
        return None

    decoder_context = ctypes.create_string_buffer(64)
    # the VP8 decoder interface is looked up as an exported data symbol
    decoder_interface = ctypes.c_void_p.in_dll(vpx_library, 'vpx_codec_vp8_dx_algo')
    # 12 is passed as the version argument - presumably the decoder ABI version, confirm against the bundled libvpx
    init_status = vpx_library.vpx_codec_dec_init_ver(decoder_context, ctypes.byref(decoder_interface), None, 0, 12)

    if init_status == 0:
        return decoder_context

    return None
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def decode(vpx_decoder : VpxDecoder, input_buffer : bytes) -> bytes:
    """
    Decode one compressed buffer and return the frame planes as bytes.

    Returns empty bytes when the library is missing, the input is empty,
    vpx_codec_decode reports a non-zero status or no frame is available.
    """
    vpx_library = vpx_module.create_static_library()

    if not vpx_library or not input_buffer:
        return bytes()

    input_total = len(input_buffer)
    # copy the payload into a ctypes array that is handed to the decoder
    packet_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer)
    decode_status = vpx_library.vpx_codec_decode(vpx_decoder, packet_buffer, input_total, None, 0)

    if decode_status != 0:
        return bytes()

    # a fresh zeroed iterator, so only the first frame is requested
    frame_iterator = ctypes.c_void_p(0)
    frame_pointer = vpx_library.vpx_codec_get_frame(vpx_decoder, ctypes.byref(frame_iterator))

    if not frame_pointer:
        return bytes()

    return collect(frame_pointer)
|
||||
|
||||
|
||||
#TODO: needs review - find better name
|
||||
def collect(frame_pointer : int) -> bytes:
    """
    Copy the three image planes behind frame_pointer into one packed buffer.

    Width and height are read from byte offsets 24 and 28, plane pointers
    from offset 48 (8 bytes each) and strides from offset 80 (4 bytes each)
    - presumably the 64-bit vpx_image_t layout, confirm against the bundled
    libvpx headers. Both dimensions are rounded down to even values; the
    second and third plane are read at half width and half height. Rows are
    copied without stride padding.
    """
    frame_width = ctypes.c_uint.from_address(frame_pointer + 24).value & ~1
    frame_height = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1
    planes_offset = frame_pointer + 48
    strides_offset = frame_pointer + 80
    # gather rows and join once; repeated bytes += copies the growing buffer on every row
    output_rows = []

    for index in range(3):
        plane_pointer = ctypes.c_void_p.from_address(planes_offset + index * 8).value
        stride = ctypes.c_int.from_address(strides_offset + index * 4).value
        # index 0 keeps the full size, the remaining planes are halved
        plane_width = frame_width >> (index > 0)
        plane_height = frame_height >> (index > 0)

        for row in range(plane_height):
            output_rows.append(ctypes.string_at(plane_pointer + row * stride, plane_width))

    return b''.join(output_rows)
|
||||
|
||||
|
||||
#TODO: needs review
|
||||
def read_resolution(vpx_decoder : VpxDecoder, input_buffer : bytes) -> Optional[Resolution]:
    """
    Decode one compressed buffer and report the frame resolution.

    Returns (width, height), both rounded down to even values, or None when
    the library is missing, the input is empty, vpx_codec_decode reports a
    non-zero status or no frame is available.
    """
    vpx_library = vpx_module.create_static_library()

    if not vpx_library or not input_buffer:
        return None

    input_total = len(input_buffer)
    packet_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer)
    decode_status = vpx_library.vpx_codec_decode(vpx_decoder, packet_buffer, input_total, None, 0)

    if decode_status != 0:
        return None

    frame_iterator = ctypes.c_void_p(0)
    frame_pointer = vpx_library.vpx_codec_get_frame(vpx_decoder, ctypes.byref(frame_iterator))

    if not frame_pointer:
        return None

    # width and height are read from byte offsets 24 and 28 of the returned image
    frame_width = ctypes.c_uint.from_address(frame_pointer + 24).value & ~1
    frame_height = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1
    return frame_width, frame_height
|
||||
|
||||
|
||||
def destroy(vpx_decoder : VpxDecoder) -> None:
    """
    Release a VP8 decoder context via vpx_codec_destroy.

    Does nothing when the vpx library could not be loaded.
    """
    vpx_library = vpx_module.create_static_library()

    if not vpx_library:
        return None

    vpx_library.vpx_codec_destroy(vpx_decoder)
    return None
|
||||
@@ -6,7 +6,7 @@ from facefusion.libraries import vpx as vpx_module
|
||||
from facefusion.types import BitRate, Resolution, VpxEncoder
|
||||
|
||||
|
||||
def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[VpxEncoder]:
|
||||
def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[VpxEncoder]:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
if vpx_library:
|
||||
@@ -37,7 +37,7 @@ def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, fram
|
||||
return None
|
||||
|
||||
|
||||
def encode_vpx_buffer(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes:
|
||||
def encode(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
output_buffer = bytes()
|
||||
|
||||
@@ -46,12 +46,12 @@ def encode_vpx_buffer(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_reso
|
||||
encode_buffer = ctypes.create_string_buffer(input_buffer)
|
||||
|
||||
if vpx_library.vpx_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and vpx_library.vpx_codec_encode(vpx_encoder, temp_buffer, frame_index, 1, 0, 1) == 0:
|
||||
output_buffer = collect_vpx_buffer(vpx_encoder)
|
||||
output_buffer = collect(vpx_encoder)
|
||||
|
||||
return output_buffer
|
||||
|
||||
|
||||
def collect_vpx_buffer(vpx_encoder : VpxEncoder) -> bytes:
|
||||
def collect(vpx_encoder : VpxEncoder) -> bytes:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
output_buffer = bytes()
|
||||
|
||||
@@ -69,7 +69,7 @@ def collect_vpx_buffer(vpx_encoder : VpxEncoder) -> bytes:
|
||||
return output_buffer
|
||||
|
||||
|
||||
def destroy_vpx_encoder(vpx_encoder : VpxEncoder) -> None:
|
||||
def destroy(vpx_encoder : VpxEncoder) -> None:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
if vpx_library:
|
||||
@@ -119,4 +119,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.aom_codec_control.argtypes = [ ctypes.c_void_p, ctypes.c_int, ctypes.c_int ]
|
||||
library.aom_codec_control.restype = ctypes.c_int
|
||||
|
||||
library.aom_codec_dec_init_ver.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_long, ctypes.c_int ]
|
||||
library.aom_codec_dec_init_ver.restype = ctypes.c_int
|
||||
|
||||
library.aom_codec_decode.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p ]
|
||||
library.aom_codec_decode.restype = ctypes.c_int
|
||||
|
||||
library.aom_codec_get_frame.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
|
||||
library.aom_codec_get_frame.restype = ctypes.c_void_p
|
||||
|
||||
return library
|
||||
|
||||
@@ -22,8 +22,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel.hash'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel.hash')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel_next.hash'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel_next.hash')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -40,8 +40,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel.so'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel.so')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel_next.so'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel_next.so')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -62,8 +62,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel.hash'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel.hash')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel_next.hash'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel_next.hash')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -80,8 +80,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel.dylib'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel.dylib')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel_next.dylib'),
|
||||
'path': resolve_relative_path('../.libraries/libdatachannel_next.dylib')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -102,8 +102,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel.hash'),
|
||||
'path': resolve_relative_path('../.libraries/datachannel.hash')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel_next.hash'),
|
||||
'path': resolve_relative_path('../.libraries/datachannel_next.hash')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -120,8 +120,8 @@ def create_static_library_set() -> Optional[LibrarySet]:
|
||||
},
|
||||
'datachannel':
|
||||
{
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel.dll'),
|
||||
'path': resolve_relative_path('../.libraries/datachannel.dll')
|
||||
'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel_next.dll'),
|
||||
'path': resolve_relative_path('../.libraries/datachannel_next.dll')
|
||||
},
|
||||
'ssl':
|
||||
{
|
||||
@@ -166,7 +166,7 @@ def create_static_library() -> Optional[ctypes.CDLL]:
|
||||
def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.rtcInitLogger.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) ]
|
||||
library.rtcInitLogger.restype = None
|
||||
library.rtcInitLogger(4, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0))
|
||||
library.rtcInitLogger(5, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0))
|
||||
|
||||
library.rtcCreatePeerConnection.restype = ctypes.c_int
|
||||
|
||||
@@ -204,6 +204,20 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
|
||||
library.rtcSetOpusPacketizer.restype = ctypes.c_int
|
||||
|
||||
library.rtcGetPayloadTypesForCodec.argtypes = [ ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int ]
|
||||
library.rtcGetPayloadTypesForCodec.restype = ctypes.c_int
|
||||
|
||||
library.rtcSetAV1Depacketizer.argtypes = [ ctypes.c_int, ctypes.c_int ]
|
||||
library.rtcSetAV1Depacketizer.restype = ctypes.c_int
|
||||
library.rtcSetVP8Depacketizer.restype = ctypes.c_int
|
||||
library.rtcSetOpusDepacketizer.restype = ctypes.c_int
|
||||
|
||||
library.rtcChainRtcpReceivingSession.argtypes = [ ctypes.c_int ]
|
||||
library.rtcChainRtcpReceivingSession.restype = ctypes.c_int
|
||||
|
||||
library.rtcReceiveMessage.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int) ]
|
||||
library.rtcReceiveMessage.restype = ctypes.c_int
|
||||
|
||||
return library
|
||||
|
||||
|
||||
|
||||
@@ -109,4 +109,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.opus_encoder_destroy.argtypes = [ ctypes.c_void_p ]
|
||||
library.opus_encoder_destroy.restype = None
|
||||
|
||||
library.opus_decoder_create.argtypes = [ ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int) ]
|
||||
library.opus_decoder_create.restype = ctypes.c_void_p
|
||||
|
||||
library.opus_decode_float.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int ]
|
||||
library.opus_decode_float.restype = ctypes.c_int
|
||||
|
||||
library.opus_decoder_destroy.argtypes = [ ctypes.c_void_p ]
|
||||
library.opus_decoder_destroy.restype = None
|
||||
|
||||
return library
|
||||
|
||||
@@ -119,4 +119,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.vpx_codec_control_.argtypes = [ ctypes.c_void_p, ctypes.c_int, ctypes.c_int ]
|
||||
library.vpx_codec_control_.restype = ctypes.c_int
|
||||
|
||||
library.vpx_codec_dec_init_ver.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_long, ctypes.c_int ]
|
||||
library.vpx_codec_dec_init_ver.restype = ctypes.c_int
|
||||
|
||||
library.vpx_codec_decode.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p, ctypes.c_long ]
|
||||
library.vpx_codec_decode.restype = ctypes.c_int
|
||||
|
||||
library.vpx_codec_get_frame.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
|
||||
library.vpx_codec_get_frame.restype = ctypes.c_void_p
|
||||
|
||||
return library
|
||||
|
||||
+100
-75
@@ -2,7 +2,7 @@ import ctypes
|
||||
from typing import List, Optional
|
||||
|
||||
from facefusion.libraries import datachannel as datachannel_module
|
||||
from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpMedia, SdpOffer, VideoCodec
|
||||
from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec
|
||||
|
||||
|
||||
def create_peer_connection() -> PeerConnection:
|
||||
@@ -47,36 +47,32 @@ def set_remote_description(peer_connection : PeerConnection, sdp_offer : SdpOffe
|
||||
return None
|
||||
|
||||
|
||||
def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_timestamp : int) -> None:
|
||||
def send_video(rtc_peer : RtcPeer, video_buffer : bytes, video_timestamp : int) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
if rtc_peers:
|
||||
send_buffer = ctypes.create_string_buffer(audio_buffer)
|
||||
send_total = len(audio_buffer)
|
||||
if rtc_peer.get('video'):
|
||||
video_track = rtc_peer.get('video').get('sender_track')
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
audio_track = rtc_peer.get('audio_track')
|
||||
|
||||
if datachannel_library.rtcIsOpen(audio_track):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp)
|
||||
datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total)
|
||||
if datachannel_library.rtcIsOpen(video_track):
|
||||
send_buffer = ctypes.create_string_buffer(video_buffer)
|
||||
send_total = len(video_buffer)
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp)
|
||||
datachannel_library.rtcSendMessage(video_track, send_buffer, send_total)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def send_video_to_peers(rtc_peers : List[RtcPeer], video_buffer : bytes, video_timestamp : int) -> None:
|
||||
def send_audio(rtc_peer : RtcPeer, audio_buffer : bytes, audio_timestamp : int) -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
|
||||
if rtc_peers:
|
||||
send_buffer = ctypes.create_string_buffer(video_buffer)
|
||||
send_total = len(video_buffer)
|
||||
if rtc_peer.get('audio'):
|
||||
audio_track = rtc_peer.get('audio').get('sender_track')
|
||||
|
||||
for rtc_peer in rtc_peers:
|
||||
video_track = rtc_peer.get('video_track')
|
||||
|
||||
if datachannel_library.rtcIsOpen(video_track):
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp)
|
||||
datachannel_library.rtcSendMessage(video_track, send_buffer, send_total)
|
||||
if datachannel_library.rtcIsOpen(audio_track):
|
||||
send_buffer = ctypes.create_string_buffer(audio_buffer)
|
||||
send_total = len(audio_buffer)
|
||||
datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp)
|
||||
datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total)
|
||||
|
||||
return None
|
||||
|
||||
@@ -98,16 +94,29 @@ def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDir
|
||||
audio_track_init = create_audio_track_init(media_direction, audio_codec, payload_type)
|
||||
audio_track = datachannel_library.rtcAddTrackEx(peer_connection, audio_track_init)
|
||||
|
||||
audio_packetizer = datachannel_module.define_rtc_packetizer_init()
|
||||
audio_packetizer.ssrc = 43
|
||||
audio_packetizer.cname = b'audio'
|
||||
audio_packetizer.payloadType = payload_type
|
||||
audio_packetizer.clockRate = 48000
|
||||
if media_direction == 'sendonly':
|
||||
audio_packetizer = datachannel_module.define_rtc_packetizer_init()
|
||||
audio_packetizer.ssrc = 43
|
||||
audio_packetizer.cname = b'audio'
|
||||
audio_packetizer.payloadType = payload_type
|
||||
audio_packetizer.clockRate = 48000
|
||||
|
||||
if audio_codec == 'opus':
|
||||
datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
|
||||
if audio_codec == 'opus':
|
||||
datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
|
||||
|
||||
datachannel_library.rtcChainRtcpSrReporter(audio_track)
|
||||
datachannel_library.rtcChainRtcpSrReporter(audio_track)
|
||||
|
||||
if media_direction == 'recvonly':
|
||||
audio_depacketizer = datachannel_module.define_rtc_packetizer_init()
|
||||
audio_depacketizer.ssrc = 0
|
||||
audio_depacketizer.cname = b'audio'
|
||||
audio_depacketizer.payloadType = payload_type
|
||||
audio_depacketizer.clockRate = 48000
|
||||
|
||||
if audio_codec == 'opus':
|
||||
datachannel_library.rtcSetOpusDepacketizer(audio_track, ctypes.byref(audio_depacketizer))
|
||||
|
||||
datachannel_library.rtcChainRtcpReceivingSession(audio_track)
|
||||
|
||||
return audio_track
|
||||
|
||||
@@ -117,86 +126,102 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir
|
||||
video_track_init = create_video_track_init(media_direction, video_codec, payload_type)
|
||||
video_track = datachannel_library.rtcAddTrackEx(peer_connection, video_track_init)
|
||||
|
||||
video_packetizer = datachannel_module.define_rtc_packetizer_init()
|
||||
video_packetizer.ssrc = 42
|
||||
video_packetizer.cname = b'video'
|
||||
video_packetizer.payloadType = payload_type
|
||||
video_packetizer.clockRate = 90000
|
||||
video_packetizer.maxFragmentSize = 1200
|
||||
if media_direction == 'sendonly':
|
||||
video_packetizer = datachannel_module.define_rtc_packetizer_init()
|
||||
video_packetizer.ssrc = 42
|
||||
video_packetizer.cname = b'video'
|
||||
video_packetizer.payloadType = payload_type
|
||||
video_packetizer.clockRate = 90000
|
||||
video_packetizer.maxFragmentSize = 1200
|
||||
|
||||
if video_codec == 'av1':
|
||||
video_packetizer.obuPacketization = 1
|
||||
datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer))
|
||||
if video_codec == 'av1':
|
||||
video_packetizer.obuPacketization = 1
|
||||
datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer))
|
||||
|
||||
if video_codec == 'vp8':
|
||||
datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
|
||||
if video_codec == 'vp8':
|
||||
datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
|
||||
|
||||
datachannel_library.rtcChainRtcpSrReporter(video_track)
|
||||
datachannel_library.rtcChainRtcpNackResponder(video_track, 512)
|
||||
datachannel_library.rtcChainRtcpSrReporter(video_track)
|
||||
datachannel_library.rtcChainRtcpNackResponder(video_track, 512)
|
||||
|
||||
if media_direction == 'recvonly':
|
||||
if video_codec == 'av1':
|
||||
datachannel_library.rtcSetAV1Depacketizer(video_track, 1)
|
||||
|
||||
if video_codec == 'vp8':
|
||||
video_depacketizer = datachannel_module.define_rtc_packetizer_init()
|
||||
video_depacketizer.ssrc = 0
|
||||
video_depacketizer.cname = b'video'
|
||||
video_depacketizer.payloadType = payload_type
|
||||
video_depacketizer.clockRate = 90000
|
||||
datachannel_library.rtcSetVP8Depacketizer(video_track, ctypes.byref(video_depacketizer))
|
||||
|
||||
datachannel_library.rtcChainRtcpReceivingSession(video_track)
|
||||
|
||||
return video_track
|
||||
|
||||
|
||||
def create_audio_track_init(media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> RtcTrackInit:
|
||||
track_init = datachannel_module.define_rtc_track_init()
|
||||
track_init.name = b'audio'
|
||||
track_init.payloadType = payload_type
|
||||
|
||||
if media_direction == 'sendonly':
|
||||
track_init.direction = 1
|
||||
track_init.mid = b'3'
|
||||
track_init.ssrc = 43
|
||||
|
||||
if media_direction == 'recvonly':
|
||||
track_init.direction = 2
|
||||
track_init.mid = b'2'
|
||||
track_init.ssrc = 45
|
||||
|
||||
if media_direction == 'sendrecv':
|
||||
track_init.direction = 3
|
||||
track_init.mid = b'1'
|
||||
track_init.ssrc = 43
|
||||
|
||||
if audio_codec == 'opus':
|
||||
track_init.codec = 128
|
||||
|
||||
track_init.payloadType = payload_type
|
||||
track_init.ssrc = 43
|
||||
track_init.name = b'audio'
|
||||
track_init.mid = b'1'
|
||||
|
||||
return ctypes.byref(track_init)
|
||||
|
||||
|
||||
def create_video_track_init(media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> RtcTrackInit:
|
||||
track_init = datachannel_module.define_rtc_track_init()
|
||||
track_init.name = b'video'
|
||||
track_init.payloadType = payload_type
|
||||
|
||||
if media_direction == 'sendonly':
|
||||
track_init.direction = 1
|
||||
track_init.mid = b'1'
|
||||
track_init.ssrc = 42
|
||||
|
||||
if media_direction == 'recvonly':
|
||||
track_init.direction = 2
|
||||
track_init.mid = b'0'
|
||||
track_init.ssrc = 44
|
||||
|
||||
if media_direction == 'sendrecv':
|
||||
track_init.direction = 3
|
||||
track_init.mid = b'0'
|
||||
track_init.ssrc = 42
|
||||
|
||||
if video_codec == 'av1':
|
||||
track_init.codec = 4
|
||||
|
||||
if video_codec == 'vp8':
|
||||
track_init.codec = 1
|
||||
|
||||
track_init.payloadType = payload_type
|
||||
track_init.ssrc = 42
|
||||
track_init.name = b'video'
|
||||
track_init.mid = b'0'
|
||||
|
||||
return ctypes.byref(track_init)
|
||||
|
||||
|
||||
def detect_sdp_media(sdp_offer : SdpOffer) -> SdpMedia:
|
||||
sdp_media : SdpMedia = {}
|
||||
def get_payload_type(sdp_offer : SdpOffer, codec : AudioCodec | VideoCodec) -> int:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
payload_type_buffer = (ctypes.c_int * 16)()
|
||||
payload_type_total = datachannel_library.rtcGetPayloadTypesForCodec(sdp_offer.encode(), codec.lower().encode(), payload_type_buffer, 16)
|
||||
|
||||
for line in sdp_offer.splitlines():
|
||||
if line.startswith('a=rtpmap:'):
|
||||
if 'av1/90000' in line.lower():
|
||||
sdp_media['video'] =\
|
||||
{
|
||||
'codec': 'av1',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
if 'vp8/90000' in line.lower():
|
||||
sdp_media['video'] =\
|
||||
{
|
||||
'codec': 'vp8',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
if 'opus/48000/2' in line.lower():
|
||||
sdp_media['audio'] =\
|
||||
{
|
||||
'codec': 'opus',
|
||||
'payload_type': int(line.removeprefix('a=rtpmap:').split()[0])
|
||||
}
|
||||
if payload_type_total:
|
||||
return payload_type_buffer[0]
|
||||
|
||||
return sdp_media
|
||||
return 0
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import List
|
||||
from facefusion import rtc
|
||||
from facefusion.types import RtcPeer, RtcStore, SessionId
|
||||
|
||||
|
||||
RTC_STORE : RtcStore = {}
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from facefusion.content_analyser import analyse_stream
|
||||
from facefusion.ffmpeg import open_ffmpeg
|
||||
from facefusion.filesystem import is_directory
|
||||
from facefusion.processors.core import get_processors_modules
|
||||
from facefusion.types import Fps, StreamMode, VisionFrame
|
||||
from facefusion.types import AudioFrame, Fps, StreamMode, VisionFrame
|
||||
from facefusion.vision import extract_vision_mask, read_static_images
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) -
|
||||
camera_capture.release()
|
||||
|
||||
if numpy.any(capture_frame):
|
||||
future = executor.submit(process_vision_frame, capture_frame)
|
||||
audio_frame = create_empty_audio_frame()
|
||||
future = executor.submit(process_frame, audio_frame, capture_frame)
|
||||
futures.append(future)
|
||||
|
||||
for future_done in [ future for future in futures if future.done() ]:
|
||||
@@ -44,11 +45,10 @@ def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) -
|
||||
yield capture_deque.popleft()
|
||||
|
||||
|
||||
def process_vision_frame(target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
def process_frame(stream_audio_frame : AudioFrame, stream_vision_frame : VisionFrame) -> VisionFrame:
|
||||
source_vision_frames = read_static_images(state_manager.get_item('source_paths'))
|
||||
source_audio_frame = create_empty_audio_frame()
|
||||
source_voice_frame = create_empty_audio_frame()
|
||||
temp_vision_frame = target_vision_frame.copy()
|
||||
temp_vision_frame = stream_vision_frame.copy()
|
||||
temp_vision_mask = extract_vision_mask(temp_vision_frame)
|
||||
|
||||
for processor_module in get_processors_modules(state_manager.get_item('processors')):
|
||||
@@ -58,9 +58,9 @@ def process_vision_frame(target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
temp_vision_frame, temp_vision_mask = processor_module.process_frame(
|
||||
{
|
||||
'source_vision_frames': source_vision_frames,
|
||||
'source_audio_frame': source_audio_frame,
|
||||
'source_audio_frame': stream_audio_frame,
|
||||
'source_voice_frame': source_voice_frame,
|
||||
'target_vision_frame': target_vision_frame,
|
||||
'target_vision_frame': stream_vision_frame,
|
||||
'temp_vision_frame': temp_vision_frame,
|
||||
'temp_vision_mask': temp_vision_mask
|
||||
})
|
||||
|
||||
+20
-3
@@ -94,8 +94,11 @@ AudioCodec : TypeAlias = Literal['opus']
|
||||
VideoCodec : TypeAlias = Literal['av1', 'vp8']
|
||||
|
||||
AomEncoder : TypeAlias = ctypes.Array[ctypes.c_char]
|
||||
AomDecoder : TypeAlias = ctypes.Array[ctypes.c_char]
|
||||
OpusEncoder : TypeAlias = ctypes.c_void_p
|
||||
OpusDecoder : TypeAlias = ctypes.c_void_p
|
||||
VpxEncoder : TypeAlias = ctypes.Array[ctypes.c_char]
|
||||
VpxDecoder : TypeAlias = ctypes.Array[ctypes.c_char]
|
||||
|
||||
BitRate : TypeAlias = int
|
||||
SampleRate : TypeAlias = int
|
||||
@@ -274,18 +277,32 @@ StreamMode = Literal['udp', 'v4l2']
|
||||
PeerConnection : TypeAlias = int
|
||||
SdpOffer : TypeAlias = str
|
||||
SdpAnswer : TypeAlias = str
|
||||
MediaDirection : TypeAlias = Literal['sendonly', 'recvonly']
|
||||
MediaDirection : TypeAlias = Literal['sendonly', 'recvonly', 'sendrecv']
|
||||
|
||||
RtcTrackInit : TypeAlias = Any
|
||||
|
||||
RtcVideoTrack : TypeAlias = int
|
||||
RtcAudioTrack : TypeAlias = int
|
||||
|
||||
RtcPeerAudio = TypedDict('RtcPeerAudio',
|
||||
{
|
||||
'sender_track': RtcAudioTrack,
|
||||
'receiver_track': RtcAudioTrack,
|
||||
'codec': AudioCodec,
|
||||
})
|
||||
|
||||
RtcPeerVideo = TypedDict('RtcPeerVideo',
|
||||
{
|
||||
'sender_track': RtcVideoTrack,
|
||||
'receiver_track': RtcVideoTrack,
|
||||
'codec': VideoCodec,
|
||||
})
|
||||
|
||||
RtcPeer = TypedDict('RtcPeer',
|
||||
{
|
||||
'peer_connection': PeerConnection,
|
||||
'video_track': RtcVideoTrack,
|
||||
'audio_track': RtcAudioTrack,
|
||||
'audio': NotRequired[RtcPeerAudio],
|
||||
'video': RtcPeerVideo,
|
||||
})
|
||||
RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]]
|
||||
|
||||
|
||||
+21
-35
@@ -1,19 +1,18 @@
|
||||
import tempfile
|
||||
import threading
|
||||
from functools import partial
|
||||
from typing import Iterator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, rtc, session_manager, state_manager
|
||||
from facefusion import metadata, rtc, rtc_store, session_manager, state_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.core import create_api
|
||||
from facefusion.core import common_pre_check
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.hash_helper import create_hash
|
||||
from facefusion.libraries import datachannel as datachannel_module
|
||||
from facefusion.types import VideoCodec
|
||||
from .assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
|
||||
@@ -37,6 +36,7 @@ def before_all() -> None:
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
asset_store.clear()
|
||||
rtc_store.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
@@ -45,16 +45,6 @@ def test_client() -> Iterator[TestClient]:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function')
|
||||
def create_event() -> threading.Event:
|
||||
return threading.Event()
|
||||
|
||||
|
||||
@pytest.mark.helper
|
||||
def set_event(session_id : str, media_buffer : bytes, timestamp : int, event : threading.Event) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
def test_stream_image(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
@@ -85,7 +75,7 @@ def test_stream_image(test_client : TestClient) -> None:
|
||||
|
||||
assert select_response.status_code == 200
|
||||
|
||||
with test_client.websocket_connect('/stream?mode=image', subprotocols =
|
||||
with test_client.websocket_connect('/stream?type=image&action=process', subprotocols =
|
||||
[
|
||||
'access_token.' + access_token
|
||||
]) as websocket:
|
||||
@@ -96,7 +86,7 @@ def test_stream_image(test_client : TestClient) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
|
||||
def test_stream_video(test_client : TestClient, create_event : threading.Event, video_codec : str) -> None:
|
||||
def test_stream_video(test_client : TestClient, video_codec : VideoCodec) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
@@ -124,27 +114,23 @@ def test_stream_video(test_client : TestClient, create_event : threading.Event,
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
with patch('facefusion.rtc.send_video_to_peers', side_effect = partial(set_event, event = create_event)):
|
||||
with test_client.websocket_connect('/stream?mode=video&codec=' + video_codec, subprotocols =
|
||||
[
|
||||
'access_token.' + access_token
|
||||
]) as websocket:
|
||||
websocket.send_bytes(chr(1).encode() + source_content)
|
||||
websocket.receive_text()
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
|
||||
peer_connection = rtc.create_peer_connection()
|
||||
rtc.add_video_track(peer_connection, 'recvonly', 'vp8', 96)
|
||||
rtc.add_audio_track(peer_connection, 'recvonly', 'opus', 111)
|
||||
sdp_offer = rtc.create_sdp_offer(peer_connection)
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
stream_response = test_client.post('/stream', content = sdp_offer, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token,
|
||||
'Content-Type': 'application/sdp'
|
||||
})
|
||||
if video_codec == 'av1':
|
||||
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 35)
|
||||
if video_codec == 'vp8':
|
||||
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 96)
|
||||
|
||||
assert stream_response.status_code == 201
|
||||
rtc.add_audio_track(peer_connection, 'sendrecv', 'opus', 111)
|
||||
sdp_offer = rtc.create_sdp_offer(peer_connection)
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
create_event.wait(timeout = 10)
|
||||
with patch('facefusion.rtc.send_video'):
|
||||
stream_response = test_client.post('/stream?type=video&action=process', content = sdp_offer, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token,
|
||||
'Content-Type': 'application/sdp'
|
||||
})
|
||||
|
||||
assert create_event.is_set()
|
||||
assert stream_response.status_code == 201
|
||||
assert 'm=video' in stream_response.text
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.aom_decoder import create, decode, destroy, read_resolution
|
||||
from facefusion.codecs.aom_encoder import create as create_encoder, encode
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.libraries import aom as aom_module
|
||||
from facefusion.vision import read_video_frame
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
	# module-wide setup that runs automatically once before the tests in this file
	download_providers = [ 'github', 'huggingface' ]
	example_urls =\
	[
		'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
	]

	# providers are registered before conditional_download() is invoked
	state_manager.init_item('download_providers', download_providers)
	conditional_download(get_test_examples_directory(), example_urls)
	aom_module.pre_check()
|
||||
|
||||
|
||||
#TODO: needs review
def test_create() -> None:
	# create() takes no arguments here and must return a truthy decoder
	aom_decoder = create()

	assert aom_decoder
|
||||
|
||||
|
||||
#TODO: needs review
def test_decode() -> None:
	# round trip: encode one example frame, then decode it with fresh decoders
	vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
	frame_height, frame_width = vision_frame.shape[:2]
	video_resolution = (frame_width, frame_height)
	video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
	aom_encoder = create_encoder(video_resolution, 1000, 1, 0)
	encoded_buffer = encode(aom_encoder, video_buffer, video_resolution, 0)
	# the expected size is derived from the resolution the decoder reports, not from the input frame
	decode_resolution = read_resolution(create(), encoded_buffer)
	decoded_buffer = decode(create(), encoded_buffer)
	# width * height * 3 // 2 - presumably the byte size of one I420 frame, TODO confirm
	expected_size = decode_resolution[0] * decode_resolution[1] * 3 // 2

	assert len(decoded_buffer) == expected_size
	# an empty input buffer must decode to empty bytes
	assert decode(create(), bytes()) == bytes()
|
||||
|
||||
|
||||
#TODO: needs review
def test_read_resolution() -> None:
	# encode one example frame and read the resolution back from the encoded buffer
	vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
	frame_height, frame_width = vision_frame.shape[:2]
	video_resolution = (frame_width, frame_height)
	video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
	aom_encoder = create_encoder(video_resolution, 1000, 1, 0)
	encoded_buffer = encode(aom_encoder, video_buffer, video_resolution, 0)

	# only a lower bound is checked - presumably the decoder may align dimensions upwards, TODO confirm
	decode_width = read_resolution(create(), encoded_buffer)[0]
	assert decode_width >= frame_width

	decode_height = read_resolution(create(), encoded_buffer)[1]
	assert decode_height >= frame_height

	# an empty buffer yields no resolution
	assert read_resolution(create(), bytes()) is None
|
||||
|
||||
|
||||
#TODO: needs review
def test_destroy() -> None:
	# destroy() must call aom_codec_destroy on the static library exactly once with the decoder
	aom_decoder = create()
	aom_library = aom_module.create_static_library()

	with patch.object(aom_library, 'aom_codec_destroy') as destroy_mock:
		destroy(aom_decoder)

	destroy_mock.assert_called_once_with(aom_decoder)
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer
|
||||
from facefusion.codecs.aom_encoder import create, destroy, encode
|
||||
from facefusion.common_helper import is_linux, is_macos, is_windows
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.hash_helper import create_hash
|
||||
@@ -22,27 +22,27 @@ def before_all() -> None:
|
||||
aom_module.pre_check()
|
||||
|
||||
|
||||
def test_create_aom_encoder() -> None:
|
||||
assert create_aom_encoder(1000, 8, 16, (320, 240))
|
||||
assert create_aom_encoder(0, 0, 0, (0, 0)) is None
|
||||
def test_create() -> None:
|
||||
assert create((320, 240), 1000, 8, 16)
|
||||
assert create((0, 0), 0, 0, 0) is None
|
||||
|
||||
|
||||
def test_encode_aom_buffer() -> None:
|
||||
def test_encode() -> None:
|
||||
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
|
||||
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
aom_encoder = create_aom_encoder(1000, 1, 0, video_resolution)
|
||||
aom_encoder = create(video_resolution, 1000, 1, 0)
|
||||
|
||||
if is_linux() or is_windows():
|
||||
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31'
|
||||
assert create_hash(encode(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31'
|
||||
|
||||
if is_macos():
|
||||
pytest.skip()
|
||||
|
||||
|
||||
def test_destroy_aom_encoder() -> None:
|
||||
aom_encoder = create_aom_encoder(1000, 8, 16, (320, 240))
|
||||
def test_destroy() -> None:
|
||||
aom_encoder = create((320, 240), 1000, 8, 16)
|
||||
|
||||
with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock:
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
destroy(aom_encoder)
|
||||
mock.assert_called_once_with(aom_encoder)
|
||||
@@ -0,0 +1,47 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.opus_decoder import create, decode, destroy
|
||||
from facefusion.codecs.opus_encoder import create as create_encoder, encode
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.ffmpeg import read_audio_buffer
|
||||
from facefusion.libraries import opus as opus_module
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
	# module-wide setup that runs automatically once before the tests in this file
	download_providers = [ 'github', 'huggingface' ]
	example_urls =\
	[
		'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3'
	]

	# providers are registered before conditional_download() is invoked
	state_manager.init_item('download_providers', download_providers)
	conditional_download(get_test_examples_directory(), example_urls)
	opus_module.pre_check()
|
||||
|
||||
|
||||
#TODO: needs review
def test_create() -> None:
	# a 48000 / 2 configuration yields a truthy decoder
	opus_decoder = create(48000, 2)

	assert opus_decoder

	# an all-zero configuration yields None
	invalid_decoder = create(0, 0)

	assert invalid_decoder is None
|
||||
|
||||
|
||||
#TODO: needs review
def test_decode() -> None:
	# round trip: encode the example audio with opus, then decode it again
	sample_rate = 48000
	channel_total = 2
	frame_size = 960
	audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), sample_rate, 16, channel_total)
	# int16 samples are scaled into float32 before encoding
	audio_sample = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32)
	audio_sample = audio_sample / 32768.0
	opus_encoder = create_encoder(sample_rate, channel_total)
	encoded_buffer = encode(opus_encoder, audio_sample.tobytes(), frame_size)
	opus_decoder = create(sample_rate, channel_total)
	decoded_buffer = decode(opus_decoder, encoded_buffer, frame_size, channel_total)

	# 960 * 2 * 4 - presumably frame size * channels * bytes per float32 sample, TODO confirm
	assert len(decoded_buffer) == frame_size * channel_total * 4
|
||||
|
||||
|
||||
#TODO: needs review
def test_destroy() -> None:
	# destroy() must call opus_decoder_destroy on the static library exactly once with the decoder
	opus_decoder = create(48000, 2)
	opus_library = opus_module.create_static_library()

	with patch.object(opus_library, 'opus_decoder_destroy') as destroy_mock:
		destroy(opus_decoder)

	destroy_mock.assert_called_once_with(opus_decoder)
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
|
||||
from facefusion.codecs.opus_encoder import create, destroy, encode
|
||||
from facefusion.common_helper import is_linux, is_macos, is_windows
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.ffmpeg import read_audio_buffer
|
||||
@@ -22,26 +22,26 @@ def before_all() -> None:
|
||||
opus_module.pre_check()
|
||||
|
||||
|
||||
def test_create_opus_encoder() -> None:
|
||||
assert create_opus_encoder(48000, 2)
|
||||
assert create_opus_encoder(0, 0) is None
|
||||
def test_create() -> None:
|
||||
assert create(48000, 2)
|
||||
assert create(0, 0) is None
|
||||
|
||||
|
||||
def test_encode_opus_buffer() -> None:
|
||||
def test_encode() -> None:
|
||||
audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2)
|
||||
audio_sample = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0
|
||||
opus_encoder = create_opus_encoder(48000, 2)
|
||||
opus_encoder = create(48000, 2)
|
||||
|
||||
if is_linux() or is_windows():
|
||||
assert create_hash(encode_opus_buffer(opus_encoder, audio_sample.tobytes(), 960)) == '8abe71cf'
|
||||
assert create_hash(encode(opus_encoder, audio_sample.tobytes(), 960)) == '8abe71cf'
|
||||
|
||||
if is_macos():
|
||||
pytest.skip()
|
||||
|
||||
|
||||
def test_destroy_opus_encoder() -> None:
|
||||
opus_encoder = create_opus_encoder(48000, 2)
|
||||
def test_destroy() -> None:
|
||||
opus_encoder = create(48000, 2)
|
||||
|
||||
with patch.object(opus_module.create_static_library(), 'opus_encoder_destroy') as mock:
|
||||
destroy_opus_encoder(opus_encoder)
|
||||
destroy(opus_encoder)
|
||||
mock.assert_called_once_with(opus_encoder)
|
||||
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.vpx_decoder import create, decode, destroy, read_resolution
|
||||
from facefusion.codecs.vpx_encoder import create as create_encoder, encode
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.libraries import vpx as vpx_module
|
||||
from facefusion.vision import read_video_frame
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
	# module-wide setup that runs automatically once before the tests in this file
	download_providers = [ 'github', 'huggingface' ]
	example_urls =\
	[
		'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
	]

	# providers are registered before conditional_download() is invoked
	state_manager.init_item('download_providers', download_providers)
	conditional_download(get_test_examples_directory(), example_urls)
	vpx_module.pre_check()
|
||||
|
||||
|
||||
#TODO: needs review
def test_create() -> None:
	# create() takes no arguments here and must return a truthy decoder
	vpx_decoder = create()

	assert vpx_decoder
|
||||
|
||||
|
||||
#TODO: needs review
def test_decode() -> None:
	# round trip: encode one example frame, then decode it with a single shared decoder
	vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
	frame_height, frame_width = vision_frame.shape[:2]
	video_resolution = (frame_width, frame_height)
	video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
	vpx_encoder = create_encoder(video_resolution, 1000, 1, 0)
	encoded_buffer = encode(vpx_encoder, video_buffer, video_resolution, 0)
	vpx_decoder = create()
	decoded_buffer = decode(vpx_decoder, encoded_buffer)
	# width * height * 3 // 2 - presumably the byte size of one I420 frame, TODO confirm
	expected_size = frame_width * frame_height * 3 // 2

	assert len(decoded_buffer) == expected_size
	# the same decoder must turn an empty input buffer into empty bytes
	assert decode(vpx_decoder, bytes()) == bytes()
|
||||
|
||||
|
||||
#TODO: needs review
def test_read_resolution() -> None:
	# encode one example frame and read the resolution back from the encoded buffer
	vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
	frame_height, frame_width = vision_frame.shape[:2]
	video_resolution = (frame_width, frame_height)
	video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
	vpx_encoder = create_encoder(video_resolution, 1000, 1, 0)
	encoded_buffer = encode(vpx_encoder, video_buffer, video_resolution, 0)
	vpx_decoder = create()
	decode_resolution = read_resolution(vpx_decoder, encoded_buffer)

	# the reported resolution must equal the resolution of the encoded frame
	assert decode_resolution == video_resolution
	# an empty buffer yields no resolution
	assert read_resolution(vpx_decoder, bytes()) is None
|
||||
|
||||
|
||||
#TODO: needs review
def test_destroy() -> None:
	# destroy() must call vpx_codec_destroy on the static library exactly once with the decoder
	vpx_decoder = create()
	vpx_library = vpx_module.create_static_library()

	with patch.object(vpx_library, 'vpx_codec_destroy') as destroy_mock:
		destroy(vpx_decoder)

	destroy_mock.assert_called_once_with(vpx_decoder)
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from tests.assert_helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
|
||||
from facefusion.codecs.vpx_encoder import create, destroy, encode
|
||||
from facefusion.common_helper import is_linux, is_macos, is_windows
|
||||
from facefusion.download import conditional_download
|
||||
from facefusion.hash_helper import create_hash
|
||||
@@ -22,27 +22,27 @@ def before_all() -> None:
|
||||
vpx_module.pre_check()
|
||||
|
||||
|
||||
def test_create_vpx_encoder() -> None:
|
||||
assert create_vpx_encoder(1000, 8, 16, (320, 240))
|
||||
assert create_vpx_encoder(0, 0, 0, (0, 0)) is None
|
||||
def test_create() -> None:
|
||||
assert create((320, 240), 1000, 8, 16)
|
||||
assert create((0, 0), 0, 0, 0) is None
|
||||
|
||||
|
||||
def test_encode_vpx_buffer() -> None:
|
||||
def test_encode() -> None:
|
||||
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
|
||||
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
vpx_encoder = create_vpx_encoder(1000, 1, 0, video_resolution)
|
||||
vpx_encoder = create(video_resolution, 1000, 1, 0)
|
||||
|
||||
if is_linux() or is_windows():
|
||||
assert create_hash(encode_vpx_buffer(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f'
|
||||
assert create_hash(encode(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f'
|
||||
|
||||
if is_macos():
|
||||
pytest.skip()
|
||||
|
||||
|
||||
def test_destroy_vpx_encoder() -> None:
|
||||
vpx_encoder = create_vpx_encoder(1000, 8, 16, (320, 240))
|
||||
def test_destroy() -> None:
|
||||
vpx_encoder = create((320, 240), 1000, 8, 16)
|
||||
|
||||
with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock:
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
destroy(vpx_encoder)
|
||||
mock.assert_called_once_with(vpx_encoder)
|
||||
+39
-29
@@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
|
||||
from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, detect_sdp_media, send_audio_to_peers, send_video_to_peers, set_remote_description
|
||||
from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, get_payload_type, send_audio, send_video, set_remote_description
|
||||
from facefusion.types import RtcPeer
|
||||
|
||||
|
||||
@@ -57,48 +57,56 @@ def test_create_sdp_answer() -> None:
|
||||
|
||||
assert 'm=video' in sdp_answer
|
||||
assert 'VP8/90000' in sdp_answer
|
||||
assert 'a=ssrc:42 cname:video' in sdp_answer
|
||||
assert 'm=audio' in sdp_answer
|
||||
assert 'opus/48000/2' in sdp_answer
|
||||
assert 'a=ssrc:43 cname:audio' in sdp_answer
|
||||
assert 'a=recvonly' in sdp_answer
|
||||
|
||||
assert datachannel_library.rtcDeletePeerConnection(sender_peer_connection) == 0
|
||||
assert datachannel_library.rtcDeletePeerConnection(receiver_peer_connection) == 0
|
||||
|
||||
|
||||
def test_send_audio_to_peers() -> None:
|
||||
def test_send_video() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = create_peer_connection()
|
||||
audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
rtc_peer : RtcPeer =\
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video':
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video_track': 0,
|
||||
'audio_track': audio_track
|
||||
'sender_track': video_track,
|
||||
'receiver_track': video_track,
|
||||
'codec': 'vp8'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
send_audio_to_peers(rtc_peers, bytes(960), 0)
|
||||
send_video(rtc_peer, bytes(1024), 0)
|
||||
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
|
||||
def test_send_video_to_peers() -> None:
|
||||
def test_send_audio() -> None:
|
||||
datachannel_library = datachannel_module.create_static_library()
|
||||
peer_connection = create_peer_connection()
|
||||
video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
rtc_peers : List[RtcPeer] =\
|
||||
[
|
||||
audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
rtc_peer : RtcPeer =\
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video':
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video_track': video_track,
|
||||
'audio_track': 0
|
||||
'sender_track': 0,
|
||||
'receiver_track': 0,
|
||||
'codec': 'vp8'
|
||||
},
|
||||
'audio':
|
||||
{
|
||||
'sender_track': audio_track,
|
||||
'receiver_track': audio_track,
|
||||
'codec': 'opus'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
send_video_to_peers(rtc_peers, bytes(1024), 0)
|
||||
send_audio(rtc_peer, bytes(960), 0)
|
||||
|
||||
datachannel_library.rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
@@ -110,8 +118,12 @@ def test_delete_peers() -> None:
|
||||
[
|
||||
{
|
||||
'peer_connection': peer_connection,
|
||||
'video_track': 0,
|
||||
'audio_track': 0
|
||||
'video':
|
||||
{
|
||||
'sender_track': 0,
|
||||
'receiver_track': 0,
|
||||
'codec': 'vp8'
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -120,16 +132,14 @@ def test_delete_peers() -> None:
|
||||
assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1
|
||||
|
||||
|
||||
def test_detect_sdp_media() -> None:
|
||||
def test_get_payload_type() -> None:
|
||||
peer_connection = create_peer_connection()
|
||||
add_video_track(peer_connection, 'sendonly', 'vp8', 96)
|
||||
add_audio_track(peer_connection, 'sendonly', 'opus', 111)
|
||||
sdp_offer = create_sdp_offer(peer_connection)
|
||||
sdp_payload = detect_sdp_media(sdp_offer)
|
||||
|
||||
assert sdp_payload.get('video').get('codec') == 'vp8'
|
||||
assert sdp_payload.get('video').get('payload_type') == 96
|
||||
assert sdp_payload.get('audio').get('codec') == 'opus'
|
||||
assert sdp_payload.get('audio').get('payload_type') == 111
|
||||
assert get_payload_type(sdp_offer, 'vp8') == 96
|
||||
assert get_payload_type(sdp_offer, 'opus') == 111
|
||||
assert get_payload_type(sdp_offer, 'av1') == 0
|
||||
|
||||
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
import asyncio
|
||||
import queue
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import pytest
|
||||
from numpy.typing import NDArray
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream
|
||||
from facefusion.hash_helper import create_hash
|
||||
from facefusion.types import VideoCodec, VisionFrame
|
||||
|
||||
|
||||
def _make_handler_websocket(events : list[Any]) -> MagicMock:
	# builds a MagicMock websocket stand-in in the CONNECTED state with an empty scope
	websocket = MagicMock()
	websocket.configure_mock(
		scope = {},
		client_state = WebSocketState.CONNECTED,
		accept = AsyncMock(),
		send_text = AsyncMock(),
		close = AsyncMock(),
		# side_effect with a list hands out one event per awaited receive() call
		receive = AsyncMock(side_effect = events)
	)
	return websocket
|
||||
|
||||
|
||||
def _make_video_packet(frame : NDArray[Any]) -> bytes:
	# JPEG-encodes the frame and prepends the 0x01 marker byte; the imencode status flag is ignored
	encoded = cv2.imencode('.jpg', frame)[1]
	return b''.join([ b'\x01', encoded.tobytes() ])
|
||||
|
||||
|
||||
def _make_audio_packet(samples : NDArray[Any]) -> bytes:
|
||||
return b'\x02' + samples.tobytes()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_encode_video_loop(video_codec : VideoCodec) -> None:
	# runs encode_video_loop against several queue contents with the encoder helpers, rtc_store and rtc patched
	prefix = 'facefusion.apis.stream_helper.'
	# the aom helpers are the default, only 'vp8' switches to the vpx helpers
	helper_names = ('create_aom_encoder', 'encode_aom_buffer', 'destroy_aom_encoder')

	if video_codec == 'vp8':
		helper_names = ('create_vpx_encoder', 'encode_vpx_buffer', 'destroy_vpx_encoder')
	create_name, encode_name, destroy_name = [ prefix + helper_name for helper_name in helper_names ]

	gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
	black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8)

	# one frame followed by None: exactly one send is expected
	vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()

	for queue_item in [ gray_frame, None ]:
		vision_frame_queue.put(queue_item)

	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store') as rtc_store_mock, \
		patch(prefix + 'rtc') as rtc_mock:
		rtc_store_mock.get_peers.return_value = [ MagicMock() ]
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		rtc_mock.send_video_to_peers.assert_called_once()

	# three frames followed by None: one send per frame is expected
	vision_frame_queue = queue.Queue()

	for queue_item in [ gray_frame, gray_frame, gray_frame, None ]:
		vision_frame_queue.put(queue_item)

	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store') as rtc_store_mock, \
		patch(prefix + 'rtc') as rtc_mock:
		rtc_store_mock.get_peers.return_value = [ MagicMock() ]
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		assert rtc_mock.send_video_to_peers.call_count == 3

	# a lone all-zero frame without None: nothing is sent and the encoder is destroyed once
	vision_frame_queue = queue.Queue()
	vision_frame_queue.put(black_frame)

	with patch(create_name, return_value = MagicMock()), \
		patch(destroy_name) as destroy_mock, \
		patch(prefix + 'rtc') as rtc_mock:
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		rtc_mock.send_video_to_peers.assert_not_called()
		destroy_mock.assert_called_once()

	# the encoder returns an empty buffer: nothing is sent
	vision_frame_queue = queue.Queue()

	for queue_item in [ gray_frame, None ]:
		vision_frame_queue.put(queue_item)

	with patch(prefix + 'process_vision_frame', return_value = gray_frame), \
		patch(create_name, return_value = MagicMock()), \
		patch(encode_name, return_value = b''), \
		patch(destroy_name), \
		patch(prefix + 'rtc_store'), \
		patch(prefix + 'rtc') as rtc_mock:
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		rtc_mock.send_video_to_peers.assert_not_called()

	# the processed frame is 128x128 while the loop starts with (64, 64): two creates and two destroys are expected
	vision_frame_queue = queue.Queue()

	for queue_item in [ small_frame, None ]:
		vision_frame_queue.put(queue_item)

	with patch(prefix + 'process_vision_frame', return_value = large_frame), \
		patch(create_name, return_value = MagicMock()) as create_mock, \
		patch(encode_name, return_value = b'encoded'), \
		patch(destroy_name) as destroy_mock, \
		patch(prefix + 'rtc_store') as rtc_store_mock, \
		patch(prefix + 'rtc') as rtc_mock:
		rtc_store_mock.get_peers.return_value = [ MagicMock() ]
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		assert create_mock.call_count == 2
		assert destroy_mock.call_count == 2
		rtc_mock.send_video_to_peers.assert_called_once()

	# encoder creation returns None: nothing is sent
	vision_frame_queue = queue.Queue()

	for queue_item in [ gray_frame, None ]:
		vision_frame_queue.put(queue_item)

	with patch(create_name, return_value = None), \
		patch(prefix + 'rtc') as rtc_mock:
		encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
		rtc_mock.send_video_to_peers.assert_not_called()
|
||||
|
||||
|
||||
# TODO: refine test
def test_encode_audio_loop() -> None:
	# runs encode_audio_loop against several queue contents with the opus helpers, rtc_store and rtc patched
	audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()

	# one chunk followed by None: exactly one send, third positional argument 0
	audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()

	for queue_item in [ audio_chunk, None ]:
		audio_chunk_queue.put(queue_item)

	with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
		patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
		patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
		patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock, \
		patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
		rtc_store_mock.get_peers.return_value = [ MagicMock() ]
		encode_audio_loop('opus', audio_chunk_queue, 'session-1')
		rtc_mock.send_audio_to_peers.assert_called_once()
		assert rtc_mock.send_audio_to_peers.call_args.args[2] == 0

	# two chunks followed by None: the third positional argument goes from 0 to 960
	audio_chunk_queue = queue.Queue()

	for queue_item in [ audio_chunk, audio_chunk, None ]:
		audio_chunk_queue.put(queue_item)

	with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
		patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \
		patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
		patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock, \
		patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
		rtc_store_mock.get_peers.return_value = [ MagicMock() ]
		encode_audio_loop('opus', audio_chunk_queue, 'session-1')
		first_call, second_call = rtc_mock.send_audio_to_peers.call_args_list[:2]
		assert rtc_mock.send_audio_to_peers.call_count == 2
		assert first_call.args[2] == 0
		assert second_call.args[2] == 960

	# the encoder returns an empty buffer: nothing is sent
	audio_chunk_queue = queue.Queue()

	for queue_item in [ audio_chunk, None ]:
		audio_chunk_queue.put(queue_item)

	with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
		patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \
		patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
		patch('facefusion.apis.stream_helper.rtc_store'), \
		patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
		encode_audio_loop('opus', audio_chunk_queue, 'session-1')
		rtc_mock.send_audio_to_peers.assert_not_called()

	# the encoder returns None: the encoder is still destroyed exactly once
	audio_chunk_queue = queue.Queue()

	for queue_item in [ audio_chunk, None ]:
		audio_chunk_queue.put(queue_item)

	with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
		patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \
		patch('facefusion.apis.stream_helper.destroy_opus_encoder') as destroy_mock, \
		patch('facefusion.apis.stream_helper.rtc_store'), \
		patch('facefusion.apis.stream_helper.rtc'):
		encode_audio_loop('opus', audio_chunk_queue, 'session-1')
		destroy_mock.assert_called_once()

	# a lone empty chunk without None: nothing is sent and the encoder is destroyed once
	audio_chunk_queue = queue.Queue()
	audio_chunk_queue.put(b'')

	with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
		patch('facefusion.apis.stream_helper.destroy_opus_encoder') as destroy_mock, \
		patch('facefusion.apis.stream_helper.rtc') as rtc_mock:
		encode_audio_loop('opus', audio_chunk_queue, 'session-1')
		rtc_mock.send_audio_to_peers.assert_not_called()
		destroy_mock.assert_called_once()
|
||||
|
||||
|
||||
# TODO: refine test
|
||||
def test_handle_video_stream() -> None:
	# Shared fixtures: one flat gray 64x64 frame and one silent audio buffer,
	# each wrapped into a packet by the module's packet helpers.
	gray_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
	video_bytes = _make_video_packet(gray_frame)
	audio_bytes = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32))

	# Scenario 1: a known session sends a single video packet, then disconnects.
	accepted_socket = _make_handler_websocket(
	[
		{'type': 'websocket.receive', 'bytes': video_bytes},
		{'type': 'websocket.disconnect'}
	])

	with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
		patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
		patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
		patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
		patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
		patch('facefusion.apis.stream_helper.encode_video_loop') as video_loop_mock, \
		patch('facefusion.apis.stream_helper.encode_audio_loop'), \
		patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock:
		asyncio.run(handle_video_stream(accepted_socket))

	accepted_socket.accept.assert_called_once_with(subprotocol = 'proto')
	accepted_socket.send_text.assert_called_once_with('ready')
	accepted_socket.close.assert_called_once()
	rtc_store_mock.init_peers.assert_called_once_with('session-1')
	rtc_store_mock.delete_peers.assert_called_once_with('session-1')

	# The video loop must have been invoked with exactly four positional arguments;
	# only the trailing session id and resolution are checked here.
	_, _, received_session_id, received_resolution = video_loop_mock.call_args[0]

	assert received_session_id == 'session-1'
	assert received_resolution == (64, 64)

	# Scenario 2: the session lookup yields nothing, so no 'ready' text is sent
	# and no peers are initialised.
	rejected_socket = _make_handler_websocket(
	[
		{'type': 'websocket.receive', 'bytes': video_bytes},
		{'type': 'websocket.disconnect'}
	])

	with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
		patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
		patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \
		patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
		patch('facefusion.apis.stream_helper.rtc_store') as rtc_store_mock:
		asyncio.run(handle_video_stream(rejected_socket))

	rejected_socket.accept.assert_called_once()
	rejected_socket.send_text.assert_not_called()
	rtc_store_mock.init_peers.assert_not_called()

	# Scenario 3: a video packet followed by an audio packet; the audio payload
	# must be readable from the queue handed to the audio loop.
	audio_socket = _make_handler_websocket(
	[
		{'type': 'websocket.receive', 'bytes': video_bytes},
		{'type': 'websocket.receive', 'bytes': audio_bytes},
		{'type': 'websocket.disconnect'}
	])

	with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \
		patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \
		patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
		patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
		patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
		patch('facefusion.apis.stream_helper.encode_video_loop'), \
		patch('facefusion.apis.stream_helper.encode_audio_loop') as audio_loop_mock, \
		patch('facefusion.apis.stream_helper.rtc_store'):
		asyncio.run(handle_video_stream(audio_socket))

	# The queue is the second positional argument passed to the audio loop.
	queued_audio = audio_loop_mock.call_args[0][1]

	assert create_hash(queued_audio.get_nowait()) == '6d72f0fc'
|
||||
Reference in New Issue
Block a user