diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index fabdd159..207a7be6 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -5,32 +5,32 @@ from starlette.websockets import WebSocket from facefusion import session_context, session_manager from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import connect_rtc, handle_image_stream, handle_video_stream +from facefusion.apis.stream_helper import process_image, process_video async def websocket_stream(websocket : WebSocket) -> None: - stream_mode = websocket.query_params.get('mode') + stream_type = websocket.query_params.get('type') - if stream_mode == 'image': - return await handle_image_stream(websocket) - - if stream_mode == 'video': - return await handle_video_stream(websocket) + if stream_type == 'image': + return await process_image(websocket) return await websocket.close(1008) async def post_stream(request : Request) -> Response: + stream_type = request.query_params.get('type') content_type = request.headers.get('content-type') access_token = extract_access_token(request.scope) session_id = session_manager.find_session_id(access_token) + session_context.set_session_id(session_id) if content_type == 'application/sdp' and session_id: sdp_offer = await request.body() - sdp_answer = connect_rtc(session_id, sdp_offer.decode()) - if sdp_answer: + if stream_type == 'video': + sdp_answer = process_video(session_id, sdp_offer.decode()) + return Response(sdp_answer, status_code = HTTP_201_CREATED, media_type = 'application/sdp') return Response(status_code = HTTP_404_NOT_FOUND) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 28be711d..d83b32fe 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,92 +1,24 @@ import asyncio -import queue # TODO: try deque +import ctypes import time from collections.abc import AsyncIterator -from functools import partial -from typing import Optional, Tuple, cast, get_args +from typing import Optional import cv2 import numpy from starlette.websockets import WebSocket, WebSocketState -from facefusion import rtc, rtc_store, session_context, session_manager, state_manager +from facefusion import rtc, rtc_store, session_context, session_manager, state_manager, streamer from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token -from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer -from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer -from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer -from facefusion.streamer import process_vision_frame -from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame +from facefusion.audio import create_empty_audio_frame +from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder +from facefusion.libraries import datachannel as datachannel_module +from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, OpusDecoder, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder -# TODO: refine this method -async def handle_video_stream(websocket : WebSocket) -> None: - subprotocol = get_sec_websocket_protocol(websocket.scope) - access_token = extract_access_token(websocket.scope) - session_id = session_manager.find_session_id(access_token) - session_context.set_session_id(session_id) - video_codec : VideoCodec = 'av1' - audio_codec : AudioCodec = 'opus' - - if websocket.query_params.get('codec') in get_args(VideoCodec): - video_codec = cast(VideoCodec, websocket.query_params.get('codec')) - - await websocket.accept(subprotocol = subprotocol) - - if session_id: - stream_frames = receive_stream_frames(websocket) - first_vision_frame : Optional[VisionFrame] = None - - async for first_frame_type, first_frame_buffer in stream_frames: - if first_frame_type == 1: - first_vision_frame = cv2.imdecode(numpy.frombuffer(first_frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) - break - - if numpy.any(first_vision_frame): - resolution : Resolution = (first_vision_frame.shape[1], first_vision_frame.shape[0]) - vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue() - audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue() - audio_temp = numpy.array([], dtype = numpy.float32) - - vision_frame_queue.put(first_vision_frame) - rtc_store.init_peers(session_id) - - event_loop = asyncio.get_running_loop() - - video_encode_task = event_loop.run_in_executor(None, encode_video_loop, video_codec, vision_frame_queue, session_id, resolution) - audio_encode_task = event_loop.run_in_executor(None, encode_audio_loop, audio_codec, audio_chunk_queue, session_id) - await websocket.send_text('ready') - - async for frame_type, frame_buffer in stream_frames: - if frame_type == 1: - vision_frame = cv2.imdecode(numpy.frombuffer(frame_buffer, numpy.uint8), cv2.IMREAD_COLOR) - - if numpy.any(vision_frame): - if vision_frame_queue.qsize(): - vision_frame_queue.get_nowait() - vision_frame_queue.put(vision_frame) - - if frame_type == 2: - audio_temp = numpy.concatenate([ audio_temp, numpy.frombuffer(frame_buffer, dtype = numpy.float32) ]) - - while len(audio_temp) >= 1920: - audio_chunk_queue.put(audio_temp[:1920].tobytes()) - audio_temp = audio_temp[1920:] - - vision_frame_queue.put(None) - audio_chunk_queue.put(None) - - await video_encode_task - await audio_encode_task - - rtc_store.delete_peers(session_id) - - if websocket.client_state == WebSocketState.CONNECTED: - await websocket.close() - - -# TODO: extract shared session setup from handle_image_stream and handle_video_stream, guard session_id like handle_video_stream -async def handle_image_stream(websocket : WebSocket) -> None: +#TODO: needs review +async def process_image(websocket : WebSocket) -> None: subprotocol = get_sec_websocket_protocol(websocket.scope) access_token = extract_access_token(websocket.scope) session_id = session_manager.find_session_id(access_token) @@ -99,7 +31,7 @@ async def handle_image_stream(websocket : WebSocket) -> None: capture_vision_frame = await anext(receive_vision_frames(websocket), None) if numpy.any(capture_vision_frame): - output_vision_frame = process_vision_frame(capture_vision_frame) + output_vision_frame = streamer.process_frame(create_empty_audio_frame(), capture_vision_frame) is_success, output_frame_buffer = cv2.imencode('.jpg', output_vision_frame) if is_success: @@ -109,108 +41,7 @@ async def handle_image_stream(websocket : WebSocket) -> None: await websocket.close() -# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop -def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: - rtc_peers = rtc_store.get_peers(session_id) - - if rtc_peers is not None: - sdp_media = rtc.detect_sdp_media(sdp_offer) - peer_connection : PeerConnection = rtc.create_peer_connection() - rtc.set_remote_description(peer_connection, sdp_offer) - - audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', sdp_media.get('audio').get('codec'), sdp_media.get('audio').get('payload_type')) - video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', sdp_media.get('video').get('codec'), sdp_media.get('video').get('payload_type')) - local_sdp = rtc.create_sdp_answer(peer_connection) - - if local_sdp: - rtc_peer : RtcPeer =\ - { - 'peer_connection': peer_connection, - 'video_track': video_track, - 'audio_track': audio_track - } - rtc_peers.append(rtc_peer) - - return local_sdp - - return None - - -def encode_video_loop(video_codec : VideoCodec, vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: - create_encoder = partial(create_aom_encoder, 4500, 8, 10) - destroy_encoder = destroy_aom_encoder - encode_buffer = encode_aom_buffer - - if video_codec == 'vp8': - create_encoder = partial(create_vpx_encoder, 4500, 8, 16) - destroy_encoder = destroy_vpx_encoder # type:ignore[assignment] - encode_buffer = encode_vpx_buffer # type:ignore[assignment] - - encoder = create_encoder(frame_resolution) - temp_resolution = frame_resolution - timestamp = 0 - - vision_frame = vision_frame_queue.get() - - while numpy.any(vision_frame) and encoder: - output_vision_frame = process_vision_frame(vision_frame) - output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) - - if output_resolution == temp_resolution: - output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - output_frame_buffer = encode_buffer(encoder, output_frame_buffer, output_resolution, timestamp) - rtc_peers = rtc_store.get_peers(session_id) - - if output_frame_buffer and rtc_peers: - video_timestamp = int(time.monotonic() * 90000) - rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp) - - timestamp += 1 - vision_frame = vision_frame_queue.get() - else: - destroy_encoder(encoder) - temp_resolution = output_resolution - encoder = create_encoder(temp_resolution) - timestamp = 0 - - if encoder: - destroy_encoder(encoder) - - -def encode_audio_loop(audio_codec : AudioCodec, audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None: - opus_encoder = create_opus_encoder(48000, 2) - audio_timestamp = 0 - - audio_chunk = audio_chunk_queue.get() - - while audio_chunk: # TODO: improve this condition with b'' - audio_buffer = encode_opus_buffer(opus_encoder, audio_chunk, 960) - rtc_peers = rtc_store.get_peers(session_id) - - if audio_buffer and rtc_peers: - rtc.send_audio_to_peers(rtc_peers, audio_buffer, audio_timestamp) - - audio_timestamp += 960 - audio_chunk = audio_chunk_queue.get() - - if opus_encoder: - destroy_opus_encoder(opus_encoder) - - -# TODO: needs refinement -async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]: - websocket_event = await websocket.receive() - - while websocket_event.get('type') == 'websocket.receive': - frame_buffer = websocket_event.get('bytes') or bytes() - - if len(frame_buffer) > 1: - yield frame_buffer[0], frame_buffer[1:] - - websocket_event = await websocket.receive() - - -# TODO: needs refinement, does it receive frames or a buffer? +#TODO: needs review async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]: websocket_event = await websocket.receive() @@ -222,3 +53,299 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr yield vision_frame websocket_event = await websocket.receive() + + +#TODO: needs review +def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: + video_codec : VideoCodec = 'vp8' + av1_payload_type = rtc.get_payload_type(sdp_offer, 'av1') + + if av1_payload_type: + video_codec = 'av1' + + video_payload_type = rtc.get_payload_type(sdp_offer, video_codec) + + if not video_payload_type: + return None + + peer_connection : PeerConnection = rtc.create_peer_connection() + video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', video_codec, video_payload_type) + video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', video_codec, video_payload_type) + + audio_codec : AudioCodec = 'opus' + audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec) + audio_receiver_track = None + audio_sender_track = None + + if audio_payload_type: + audio_receiver_track = rtc.add_audio_track(peer_connection, 'recvonly', audio_codec, audio_payload_type) + audio_sender_track = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, audio_payload_type) + + rtc.set_remote_description(peer_connection, sdp_offer) + local_sdp = rtc.create_sdp_answer(peer_connection) + + if local_sdp: + rtc_peer : RtcPeer =\ + { + 'peer_connection': peer_connection, + 'video': + { + 'sender_track': video_sender_track, + 'receiver_track': video_receiver_track, + 'codec': video_codec + } + } + + if audio_receiver_track and audio_sender_track: + rtc_peer['audio'] =\ + { + 'sender_track': audio_sender_track, + 'receiver_track': audio_receiver_track, + 'codec': audio_codec + } + + rtc_store.init_peers(session_id) + rtc_store.get_peers(session_id).append(rtc_peer) + + event_loop = asyncio.get_event_loop() + event_loop.run_in_executor(None, run_peer_loop, session_id, rtc_peer) + + return local_sdp + + +#TODO: needs review +def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: + datachannel_library = datachannel_module.create_static_library() + video_info = rtc_peer.get('video') + video_codec = video_info.get('codec') + video_decoder = create_video_decoder(video_codec) + audio_info = rtc_peer.get('audio') + audio_decoder = opus_decoder.create(48000, 2) if audio_info else None + video_receive_buffer = ctypes.create_string_buffer(512 * 1024) + audio_receive_buffer = ctypes.create_string_buffer(8 * 1024) + + frame_buffer = poll_for_buffer(datachannel_library, video_info.get('receiver_track'), video_receive_buffer, 30.0) + + if frame_buffer is None: + cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder) + return + + resolution = read_video_resolution(video_codec, video_decoder, frame_buffer) + + if resolution is None: + cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder) + return + + vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer, resolution) + + if vision_frame is None: + cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder) + return + + audio_frame = create_empty_audio_frame() + video_encoder = create_video_encoder(video_codec, resolution) + audio_encoder = opus_encoder.create(48000, 2) + frame_index = 0 + + while True: + if audio_info and audio_decoder: + audio_frame = receive_audio_frame(datachannel_library, audio_info.get('receiver_track'), audio_decoder, audio_receive_buffer) + + output_vision_frame = streamer.process_frame(audio_frame, vision_frame) + output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) + + if output_resolution != resolution: + resolution = output_resolution + destroy_video_encoder(video_codec, video_encoder) + video_encoder = create_video_encoder(video_codec, resolution) + + raw_vision_frame = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420) + + if video_codec == 'av1': + encoded_video_buffer = aom_encoder.encode(video_encoder, raw_vision_frame.tobytes(), resolution, frame_index) + if video_codec == 'vp8': + encoded_video_buffer = vpx_encoder.encode(video_encoder, raw_vision_frame.tobytes(), resolution, frame_index) + + if encoded_video_buffer: + video_timestamp = int(time.monotonic() * 90000) + rtc.send_video(rtc_peer, encoded_video_buffer, video_timestamp) + + if audio_encoder and audio_frame is not None and audio_frame.size > 0: + encoded_audio_buffer = opus_encoder.encode(audio_encoder, audio_frame.tobytes(), 960) + + if encoded_audio_buffer: + audio_timestamp = int(time.monotonic() * 48000) + rtc.send_audio(rtc_peer, encoded_audio_buffer, audio_timestamp) + + frame_index += 1 + + next_frame = drain_to_latest_frame(datachannel_library, video_info.get('receiver_track'), video_codec, video_decoder, video_receive_buffer, resolution) + + if next_frame is not None: + vision_frame = next_frame + continue + + next_frame = poll_for_frame(datachannel_library, video_info.get('receiver_track'), video_codec, video_decoder, video_receive_buffer, resolution, 30.0) + + if next_frame is None: + break + + vision_frame = next_frame + + destroy_video_encoder(video_codec, video_encoder) + opus_encoder.destroy(audio_encoder) + cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder) + + +#TODO: needs review +def cleanup_peer(session_id : SessionId, rtc_peer : RtcPeer, video_codec : VideoCodec, video_decoder : Optional[VpxDecoder | AomDecoder], audio_decoder : Optional[OpusDecoder]) -> None: + if video_decoder: + if video_codec == 'av1': + aom_decoder.destroy(video_decoder) + if video_codec == 'vp8': + vpx_decoder.destroy(video_decoder) + + if audio_decoder: + opus_decoder.destroy(audio_decoder) + + rtc_store.delete_peers(session_id) + + +#TODO: needs review +def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomDecoder]: + if video_codec == 'av1': + return aom_decoder.create() + if video_codec == 'vp8': + return vpx_decoder.create() + + return None + + +#TODO: needs review - remove as both are the same +def create_video_encoder(video_codec : VideoCodec, resolution : Resolution) -> Optional[VpxEncoder | AomEncoder]: + if video_codec == 'av1': + return aom_encoder.create(resolution, 8000, 8, 10) + if video_codec == 'vp8': + return vpx_encoder.create(resolution, 8000, 8, 10) + + return None + + +#TODO: needs review - remove as this is a trivial helper +def destroy_video_encoder(video_codec : VideoCodec, video_encoder : Optional[VpxEncoder | AomEncoder]) -> None: + if video_codec == 'av1': + aom_encoder.destroy(video_encoder) + if video_codec == 'vp8': + vpx_encoder.destroy(video_encoder) + + +def read_video_resolution(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[Resolution]: + if video_codec == 'av1': + return aom_decoder.read_resolution(video_decoder, frame_buffer) + if video_codec == 'vp8': + return vpx_decoder.read_resolution(video_decoder, frame_buffer) + + return None + + +def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes, frame_resolution : Resolution) -> Optional[VisionFrame]: + output_buffer = bytes() + + if video_codec == 'av1': + output_buffer = aom_decoder.decode(video_decoder, frame_buffer) + if video_codec == 'vp8': + output_buffer = vpx_decoder.decode(video_decoder, frame_buffer) + + if output_buffer: + frame_width, frame_height = frame_resolution + yuv_frame = numpy.frombuffer(output_buffer, dtype = numpy.uint8).reshape((frame_height * 3 // 2, frame_width)) + return cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420) + + return None + + +#TODO: needs review +def receive_audio_frame(datachannel_library : ctypes.CDLL, audio_track : int, audio_decoder : OpusDecoder, receive_buffer : ctypes.Array[ctypes.c_char]) -> AudioFrame: + buffer_size = ctypes.c_int(8 * 1024) + receive_output = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size)) + + if receive_output == 0 and buffer_size.value > 0: + opus_buffer = receive_buffer.raw[:buffer_size.value] + output_buffer = opus_decoder.decode(audio_decoder, opus_buffer, 960, 2) + + if output_buffer: + return numpy.frombuffer(output_buffer, dtype = numpy.float32) + + return create_empty_audio_frame() + + +def receive_video_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[bytes]: + buffer_size = ctypes.c_int(512 * 1024) + receive_output = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size)) + + if receive_output == 0 and buffer_size.value > 0: + return receive_buffer.raw[:buffer_size.value] + + return None + + +def poll_for_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char], timeout : float) -> Optional[bytes]: + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer) + + if frame_buffer is not None: + return frame_buffer + + time.sleep(0.001) + + return None + + +#TODO: needs review +def poll_for_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution, timeout : float) -> Optional[VisionFrame]: + deadline = time.monotonic() + timeout + + while time.monotonic() < deadline: + vision_frame = try_receive_frame(datachannel_library, video_track, video_codec, video_decoder, receive_buffer, frame_resolution) + + if vision_frame is not None: + return vision_frame + + time.sleep(0.001) + + return None + + +#TODO: needs review +def try_receive_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution) -> Optional[VisionFrame]: + frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer) + + if frame_buffer: + return decode_video_frame(video_codec, video_decoder, frame_buffer, frame_resolution) + + return None + + +#TODO: needs review +def drain_to_latest_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], frame_resolution : Resolution) -> Optional[VisionFrame]: + last_vision_frame = numpy.empty(0) + buffer_size = ctypes.c_int(512 * 1024) + receive_output = 0 + + while receive_output == 0 and buffer_size.value > 0: + buffer_size.value = 512 * 1024 + receive_output = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size)) + + if receive_output == 0 and buffer_size.value > 0: + frame_buffer = receive_buffer.raw[:buffer_size.value] + vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer, frame_resolution) + + if numpy.any(vision_frame): + last_vision_frame = vision_frame + + if numpy.any(last_vision_frame): + return last_vision_frame + + return None diff --git a/facefusion/codecs/aom_decoder.py b/facefusion/codecs/aom_decoder.py new file mode 100644 index 00000000..46f3c909 --- /dev/null +++ b/facefusion/codecs/aom_decoder.py @@ -0,0 +1,82 @@ +import ctypes +from typing import Optional + +from facefusion.libraries import aom as aom_module +from facefusion.types import AomDecoder, Resolution + + +def create() -> Optional[AomDecoder]: + aom_library = aom_module.create_static_library() + + if aom_library: + aom_decoder = ctypes.create_string_buffer(128) + aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_dx_algo') + + if aom_library.aom_codec_dec_init_ver(aom_decoder, ctypes.byref(aom_codec), None, 0, 22) == 0: + return aom_decoder + + return None + + +#TODO: needs review +def decode(aom_decoder : AomDecoder, input_buffer : bytes) -> bytes: + aom_library = aom_module.create_static_library() + output_buffer = bytes() + + if aom_library and input_buffer: + input_total = len(input_buffer) + temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer) + + if aom_library.aom_codec_decode(aom_decoder, temp_buffer, input_total, None) == 0: + frame_pointer = aom_library.aom_codec_get_frame(aom_decoder, ctypes.byref(ctypes.c_void_p(0))) + + if frame_pointer: + output_buffer = collect(frame_pointer) + + return output_buffer + + +#TODO: needs review +def collect(frame_pointer : int) -> bytes: + frame_width = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1 + frame_height = ctypes.c_uint.from_address(frame_pointer + 32).value & ~1 + planes_offset = frame_pointer + 64 + strides_offset = frame_pointer + 88 + output_buffer = bytes() + + for index in range(3): + plane_pointer = ctypes.c_void_p.from_address(planes_offset + index * 8).value + stride = ctypes.c_int.from_address(strides_offset + index * 4).value + plane_width = frame_width >> (index > 0) + plane_height = frame_height >> (index > 0) + + for row in range(plane_height): + output_buffer += ctypes.string_at(plane_pointer + row * stride, plane_width) + + return output_buffer + + +#TODO: needs review +def read_resolution(aom_decoder : AomDecoder, input_buffer : bytes) -> Optional[Resolution]: + aom_library = aom_module.create_static_library() + + if aom_library and input_buffer: + input_total = len(input_buffer) + temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer) + + if aom_library.aom_codec_decode(aom_decoder, temp_buffer, input_total, None) == 0: + frame_pointer = aom_library.aom_codec_get_frame(aom_decoder, ctypes.byref(ctypes.c_void_p(0))) + + if frame_pointer: + frame_width = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1 + frame_height = ctypes.c_uint.from_address(frame_pointer + 32).value & ~1 + return frame_width, frame_height + + return None + + +def destroy(aom_decoder : AomDecoder) -> None: + aom_library = aom_module.create_static_library() + + if aom_library: + aom_library.aom_codec_destroy(aom_decoder) diff --git a/facefusion/codecs/aom.py b/facefusion/codecs/aom_encoder.py similarity index 84% rename from facefusion/codecs/aom.py rename to facefusion/codecs/aom_encoder.py index ddb42608..0531b8c9 100644 --- a/facefusion/codecs/aom.py +++ b/facefusion/codecs/aom_encoder.py @@ -6,7 +6,7 @@ from facefusion.libraries import aom as aom_module from facefusion.types import AomEncoder, BitRate, Resolution -def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[AomEncoder]: +def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[AomEncoder]: aom_library = aom_module.create_static_library() if aom_library: @@ -33,7 +33,7 @@ def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, fram return None -def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes: +def encode(aom_encoder : AomEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes: aom_library = aom_module.create_static_library() output_buffer = bytes() @@ -42,12 +42,12 @@ def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_reso encode_buffer = ctypes.create_string_buffer(input_buffer) if aom_library.aom_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and aom_library.aom_codec_encode(aom_encoder, temp_buffer, frame_index, 1, 0, 1) == 0: - output_buffer = collect_aom_buffer(aom_encoder) + output_buffer = collect(aom_encoder) return output_buffer -def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes: +def collect(aom_encoder : AomEncoder) -> bytes: aom_library = aom_module.create_static_library() output_buffer = bytes() @@ -65,7 +65,7 @@ def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes: return output_buffer -def destroy_aom_encoder(aom_encoder : AomEncoder) -> None: +def destroy(aom_encoder : AomEncoder) -> None: aom_library = aom_module.create_static_library() if aom_library: diff --git a/facefusion/codecs/opus_decoder.py b/facefusion/codecs/opus_decoder.py new file mode 100644 index 00000000..3691f436 --- /dev/null +++ b/facefusion/codecs/opus_decoder.py @@ -0,0 +1,36 @@ +import ctypes +from typing import Optional + +from facefusion.libraries import opus as opus_module +from facefusion.types import OpusDecoder + + +def create(sample_rate : int, channel_total : int) -> Optional[OpusDecoder]: + opus_library = opus_module.create_static_library() + + if opus_library: + return opus_library.opus_decoder_create(sample_rate, channel_total, ctypes.byref(ctypes.c_int(0))) + + return None + + +def decode(opus_decoder : OpusDecoder, input_buffer : bytes, frame_size : int, channel_total : int) -> bytes: + opus_library = opus_module.create_static_library() + output_buffer = bytes() + + if opus_library: + input_total = len(input_buffer) + decode_buffer = (ctypes.c_float * (frame_size * channel_total))() + decode_length = opus_library.opus_decode_float(opus_decoder, input_buffer, input_total, decode_buffer, frame_size, 0) + + if decode_length: + output_buffer = ctypes.string_at(ctypes.addressof(decode_buffer), decode_length * channel_total * ctypes.sizeof(ctypes.c_float)) + + return output_buffer + + +def destroy(opus_decoder : OpusDecoder) -> None: + opus_library = opus_module.create_static_library() + + if opus_library: + opus_library.opus_decoder_destroy(opus_decoder) diff --git a/facefusion/codecs/opus.py b/facefusion/codecs/opus_encoder.py similarity index 78% rename from facefusion/codecs/opus.py rename to facefusion/codecs/opus_encoder.py index b34bcb6b..3e431db7 100644 --- a/facefusion/codecs/opus.py +++ b/facefusion/codecs/opus_encoder.py @@ -5,7 +5,7 @@ from facefusion.libraries import opus as opus_module from facefusion.types import OpusEncoder -def create_opus_encoder(sample_rate : int, channel_total : int) -> Optional[OpusEncoder]: +def create(sample_rate : int, channel_total : int) -> Optional[OpusEncoder]: opus_library = opus_module.create_static_library() if opus_library: @@ -14,7 +14,7 @@ def create_opus_encoder(sample_rate : int, channel_total : int) -> Optional[Opus return None -def encode_opus_buffer(opus_encoder : OpusEncoder, input_buffer : bytes, frame_size : int) -> bytes: +def encode(opus_encoder : OpusEncoder, input_buffer : bytes, frame_size : int) -> bytes: opus_library = opus_module.create_static_library() output_buffer = bytes() @@ -29,7 +29,7 @@ def encode_opus_buffer(opus_encoder : OpusEncoder, input_buffer : bytes, frame_s return output_buffer -def destroy_opus_encoder(opus_encoder : OpusEncoder) -> None: +def destroy(opus_encoder : OpusEncoder) -> None: opus_library = opus_module.create_static_library() if opus_library: diff --git a/facefusion/codecs/vpx_decoder.py b/facefusion/codecs/vpx_decoder.py new file mode 100644 index 00000000..51619a53 --- /dev/null +++ b/facefusion/codecs/vpx_decoder.py @@ -0,0 +1,82 @@ +import ctypes +from typing import Optional + +from facefusion.libraries import vpx as vpx_module +from facefusion.types import Resolution, VpxDecoder + + +def create() -> Optional[VpxDecoder]: + vpx_library = vpx_module.create_static_library() + + if vpx_library: + vpx_decoder = ctypes.create_string_buffer(64) + vpx_codec = ctypes.c_void_p.in_dll(vpx_library, 'vpx_codec_vp8_dx_algo') + + if vpx_library.vpx_codec_dec_init_ver(vpx_decoder, ctypes.byref(vpx_codec), None, 0, 12) == 0: + return vpx_decoder + + return None + + +#TODO: needs review +def decode(vpx_decoder : VpxDecoder, input_buffer : bytes) -> bytes: + vpx_library = vpx_module.create_static_library() + output_buffer = bytes() + + if vpx_library and input_buffer: + input_total = len(input_buffer) + temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer) + + if vpx_library.vpx_codec_decode(vpx_decoder, temp_buffer, input_total, None, 0) == 0: + frame_pointer = vpx_library.vpx_codec_get_frame(vpx_decoder, ctypes.byref(ctypes.c_void_p(0))) + + if frame_pointer: + output_buffer = collect(frame_pointer) + + return output_buffer + + +#TODO: needs review - find better name +def collect(frame_pointer : int) -> bytes: + frame_width = ctypes.c_uint.from_address(frame_pointer + 24).value & ~1 + frame_height = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1 + planes_offset = frame_pointer + 48 + strides_offset = frame_pointer + 80 + output_buffer = bytes() + + for index in range(3): + plane_pointer = ctypes.c_void_p.from_address(planes_offset + index * 8).value + stride = ctypes.c_int.from_address(strides_offset + index * 4).value + plane_width = frame_width >> (index > 0) + plane_height = frame_height >> (index > 0) + + for row in range(plane_height): + output_buffer += ctypes.string_at(plane_pointer + row * stride, plane_width) + + return output_buffer + + +#TODO: needs review +def read_resolution(vpx_decoder : VpxDecoder, input_buffer : bytes) -> Optional[Resolution]: + vpx_library = vpx_module.create_static_library() + + if vpx_library and input_buffer: + input_total = len(input_buffer) + temp_buffer = (ctypes.c_uint8 * input_total).from_buffer_copy(input_buffer) + + if vpx_library.vpx_codec_decode(vpx_decoder, temp_buffer, input_total, None, 0) == 0: + frame_pointer = vpx_library.vpx_codec_get_frame(vpx_decoder, ctypes.byref(ctypes.c_void_p(0))) + + if frame_pointer: + frame_width = ctypes.c_uint.from_address(frame_pointer + 24).value & ~1 + frame_height = ctypes.c_uint.from_address(frame_pointer + 28).value & ~1 + return frame_width, frame_height + + return None + + +def destroy(vpx_decoder : VpxDecoder) -> None: + vpx_library = vpx_module.create_static_library() + + if vpx_library: + vpx_library.vpx_codec_destroy(vpx_decoder) diff --git a/facefusion/codecs/vpx.py b/facefusion/codecs/vpx_encoder.py similarity index 85% rename from facefusion/codecs/vpx.py rename to facefusion/codecs/vpx_encoder.py index 2f76a5bf..8a8e7fa7 100644 --- a/facefusion/codecs/vpx.py +++ b/facefusion/codecs/vpx_encoder.py @@ -6,7 +6,7 @@ from facefusion.libraries import vpx as vpx_module from facefusion.types import BitRate, Resolution, VpxEncoder -def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[VpxEncoder]: +def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[VpxEncoder]: vpx_library = vpx_module.create_static_library() if vpx_library: @@ -37,7 +37,7 @@ def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, fram return None -def encode_vpx_buffer(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes: +def encode(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes: vpx_library = vpx_module.create_static_library() output_buffer = bytes() @@ -46,12 +46,12 @@ def encode_vpx_buffer(vpx_encoder : VpxEncoder, input_buffer : bytes, frame_reso encode_buffer = ctypes.create_string_buffer(input_buffer) if vpx_library.vpx_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and vpx_library.vpx_codec_encode(vpx_encoder, temp_buffer, frame_index, 1, 0, 1) == 0: - output_buffer = collect_vpx_buffer(vpx_encoder) + output_buffer = collect(vpx_encoder) return output_buffer -def collect_vpx_buffer(vpx_encoder : VpxEncoder) -> bytes: +def collect(vpx_encoder : VpxEncoder) -> bytes: vpx_library = vpx_module.create_static_library() output_buffer = bytes() @@ -69,7 +69,7 @@ def collect_vpx_buffer(vpx_encoder : VpxEncoder) -> bytes: return output_buffer -def destroy_vpx_encoder(vpx_encoder : VpxEncoder) -> None: +def destroy(vpx_encoder : VpxEncoder) -> None: vpx_library = vpx_module.create_static_library() if vpx_library: diff --git a/facefusion/libraries/aom.py b/facefusion/libraries/aom.py index 2fff393c..e176b88e 100644 --- a/facefusion/libraries/aom.py +++ b/facefusion/libraries/aom.py @@ -119,4 +119,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.aom_codec_control.argtypes = [ ctypes.c_void_p, ctypes.c_int, ctypes.c_int ] library.aom_codec_control.restype = ctypes.c_int + library.aom_codec_dec_init_ver.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_long, ctypes.c_int ] + library.aom_codec_dec_init_ver.restype = ctypes.c_int + + library.aom_codec_decode.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p ] + library.aom_codec_decode.restype = ctypes.c_int + + library.aom_codec_get_frame.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ] + library.aom_codec_get_frame.restype = ctypes.c_void_p + return library diff --git a/facefusion/libraries/datachannel.py b/facefusion/libraries/datachannel.py index ff4b3064..5e8699cb 100644 --- a/facefusion/libraries/datachannel.py +++ b/facefusion/libraries/datachannel.py @@ -22,8 +22,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel.hash'), - 'path': resolve_relative_path('../.libraries/libdatachannel.hash') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel_next.hash'), + 'path': resolve_relative_path('../.libraries/libdatachannel_next.hash') }, 'ssl': { @@ -40,8 +40,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel.so'), - 'path': resolve_relative_path('../.libraries/libdatachannel.so') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'linux/libdatachannel_next.so'), + 'path': resolve_relative_path('../.libraries/libdatachannel_next.so') }, 'ssl': { @@ -62,8 +62,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel.hash'), - 'path': resolve_relative_path('../.libraries/libdatachannel.hash') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel_next.hash'), + 'path': resolve_relative_path('../.libraries/libdatachannel_next.hash') }, 'ssl': { @@ -80,8 +80,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel.dylib'), - 'path': resolve_relative_path('../.libraries/libdatachannel.dylib') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'macos/libdatachannel_next.dylib'), + 'path': resolve_relative_path('../.libraries/libdatachannel_next.dylib') }, 'ssl': { @@ -102,8 +102,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel.hash'), - 'path': resolve_relative_path('../.libraries/datachannel.hash') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel_next.hash'), + 'path': resolve_relative_path('../.libraries/datachannel_next.hash') }, 'ssl': { @@ -120,8 +120,8 @@ def create_static_library_set() -> Optional[LibrarySet]: }, 'datachannel': { - 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel.dll'), - 'path': resolve_relative_path('../.libraries/datachannel.dll') + 'url': resolve_download_url_by_provider('huggingface', 'libraries-4.0.0', 'windows/datachannel_next.dll'), + 'path': resolve_relative_path('../.libraries/datachannel_next.dll') }, 'ssl': { @@ -166,7 +166,7 @@ def create_static_library() -> Optional[ctypes.CDLL]: def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcInitLogger.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) ] library.rtcInitLogger.restype = None - library.rtcInitLogger(4, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0)) + library.rtcInitLogger(5, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0)) library.rtcCreatePeerConnection.restype = ctypes.c_int @@ -204,6 +204,20 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcSetOpusPacketizer.restype = ctypes.c_int + library.rtcGetPayloadTypesForCodec.argtypes = [ ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int ] + library.rtcGetPayloadTypesForCodec.restype = ctypes.c_int + + library.rtcSetAV1Depacketizer.argtypes = [ ctypes.c_int, ctypes.c_int ] + library.rtcSetAV1Depacketizer.restype = ctypes.c_int + library.rtcSetVP8Depacketizer.restype = ctypes.c_int + library.rtcSetOpusDepacketizer.restype = ctypes.c_int + + library.rtcChainRtcpReceivingSession.argtypes = [ ctypes.c_int ] + library.rtcChainRtcpReceivingSession.restype = ctypes.c_int + + library.rtcReceiveMessage.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int) ] + library.rtcReceiveMessage.restype = ctypes.c_int + return library diff --git a/facefusion/libraries/opus.py b/facefusion/libraries/opus.py index 7e67450d..ccb2052d 100644 --- a/facefusion/libraries/opus.py +++ b/facefusion/libraries/opus.py @@ -109,4 +109,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.opus_encoder_destroy.argtypes = [ ctypes.c_void_p ] library.opus_encoder_destroy.restype = None + library.opus_decoder_create.argtypes = [ ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int) ] + library.opus_decoder_create.restype = ctypes.c_void_p + + library.opus_decode_float.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int ] + library.opus_decode_float.restype = ctypes.c_int + + library.opus_decoder_destroy.argtypes = [ ctypes.c_void_p ] + library.opus_decoder_destroy.restype = None + return library diff --git a/facefusion/libraries/vpx.py b/facefusion/libraries/vpx.py index b73deaed..c5a91fa4 100644 --- a/facefusion/libraries/vpx.py +++ b/facefusion/libraries/vpx.py @@ -119,4 +119,13 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.vpx_codec_control_.argtypes = [ ctypes.c_void_p, ctypes.c_int, ctypes.c_int ] library.vpx_codec_control_.restype = ctypes.c_int + library.vpx_codec_dec_init_ver.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_long, ctypes.c_int ] + library.vpx_codec_dec_init_ver.restype = ctypes.c_int + + library.vpx_codec_decode.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.c_void_p, ctypes.c_long ] + library.vpx_codec_decode.restype = ctypes.c_int + + library.vpx_codec_get_frame.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ] + library.vpx_codec_get_frame.restype = ctypes.c_void_p + return library diff --git a/facefusion/rtc.py b/facefusion/rtc.py index fb66e185..989081ec 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -2,7 +2,7 @@ import ctypes from typing import List, Optional from facefusion.libraries import datachannel as datachannel_module -from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpMedia, SdpOffer, VideoCodec +from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcTrackInit, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec def create_peer_connection() -> PeerConnection: @@ -47,36 +47,32 @@ def set_remote_description(peer_connection : PeerConnection, sdp_offer : SdpOffe return None -def send_audio_to_peers(rtc_peers : List[RtcPeer], audio_buffer : bytes, audio_timestamp : int) -> None: +def send_video(rtc_peer : RtcPeer, video_buffer : bytes, video_timestamp : int) -> None: datachannel_library = datachannel_module.create_static_library() - if rtc_peers: - send_buffer = ctypes.create_string_buffer(audio_buffer) - send_total = len(audio_buffer) + if rtc_peer.get('video'): + video_track = rtc_peer.get('video').get('sender_track') - for rtc_peer in rtc_peers: - audio_track = rtc_peer.get('audio_track') - - if datachannel_library.rtcIsOpen(audio_track): - datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp) - datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total) + if datachannel_library.rtcIsOpen(video_track): + send_buffer = ctypes.create_string_buffer(video_buffer) + send_total = len(video_buffer) + datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp) + datachannel_library.rtcSendMessage(video_track, send_buffer, send_total) return None -def send_video_to_peers(rtc_peers : List[RtcPeer], video_buffer : bytes, video_timestamp : int) -> None: +def send_audio(rtc_peer : RtcPeer, audio_buffer : bytes, audio_timestamp : int) -> None: datachannel_library = datachannel_module.create_static_library() - if rtc_peers: - send_buffer = ctypes.create_string_buffer(video_buffer) - send_total = len(video_buffer) + if rtc_peer.get('audio'): + audio_track = rtc_peer.get('audio').get('sender_track') - for rtc_peer in rtc_peers: - video_track = rtc_peer.get('video_track') - - if datachannel_library.rtcIsOpen(video_track): - datachannel_library.rtcSetTrackRtpTimestamp(video_track, video_timestamp) - datachannel_library.rtcSendMessage(video_track, send_buffer, send_total) + if datachannel_library.rtcIsOpen(audio_track): + send_buffer = ctypes.create_string_buffer(audio_buffer) + send_total = len(audio_buffer) + datachannel_library.rtcSetTrackRtpTimestamp(audio_track, audio_timestamp) + datachannel_library.rtcSendMessage(audio_track, send_buffer, send_total) return None @@ -98,16 +94,29 @@ def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDir audio_track_init = create_audio_track_init(media_direction, audio_codec, payload_type) audio_track = datachannel_library.rtcAddTrackEx(peer_connection, audio_track_init) - audio_packetizer = datachannel_module.define_rtc_packetizer_init() - audio_packetizer.ssrc = 43 - audio_packetizer.cname = b'audio' - audio_packetizer.payloadType = payload_type - audio_packetizer.clockRate = 48000 + if media_direction == 'sendonly': + audio_packetizer = datachannel_module.define_rtc_packetizer_init() + audio_packetizer.ssrc = 43 + audio_packetizer.cname = b'audio' + audio_packetizer.payloadType = payload_type + audio_packetizer.clockRate = 48000 - if audio_codec == 'opus': - datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer)) + if audio_codec == 'opus': + datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer)) - datachannel_library.rtcChainRtcpSrReporter(audio_track) + datachannel_library.rtcChainRtcpSrReporter(audio_track) + + if media_direction == 'recvonly': + audio_depacketizer = datachannel_module.define_rtc_packetizer_init() + audio_depacketizer.ssrc = 0 + audio_depacketizer.cname = b'audio' + audio_depacketizer.payloadType = payload_type + audio_depacketizer.clockRate = 48000 + + if audio_codec == 'opus': + datachannel_library.rtcSetOpusDepacketizer(audio_track, ctypes.byref(audio_depacketizer)) + + datachannel_library.rtcChainRtcpReceivingSession(audio_track) return audio_track @@ -117,86 +126,102 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir video_track_init = create_video_track_init(media_direction, video_codec, payload_type) video_track = datachannel_library.rtcAddTrackEx(peer_connection, video_track_init) - video_packetizer = datachannel_module.define_rtc_packetizer_init() - video_packetizer.ssrc = 42 - video_packetizer.cname = b'video' - video_packetizer.payloadType = payload_type - video_packetizer.clockRate = 90000 - video_packetizer.maxFragmentSize = 1200 + if media_direction == 'sendonly': + video_packetizer = datachannel_module.define_rtc_packetizer_init() + video_packetizer.ssrc = 42 + video_packetizer.cname = b'video' + video_packetizer.payloadType = payload_type + video_packetizer.clockRate = 90000 + video_packetizer.maxFragmentSize = 1200 - if video_codec == 'av1': - video_packetizer.obuPacketization = 1 - datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer)) + if video_codec == 'av1': + video_packetizer.obuPacketization = 1 + datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer)) - if video_codec == 'vp8': - datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) + if video_codec == 'vp8': + datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) - datachannel_library.rtcChainRtcpSrReporter(video_track) - datachannel_library.rtcChainRtcpNackResponder(video_track, 512) + datachannel_library.rtcChainRtcpSrReporter(video_track) + datachannel_library.rtcChainRtcpNackResponder(video_track, 512) + + if media_direction == 'recvonly': + if video_codec == 'av1': + datachannel_library.rtcSetAV1Depacketizer(video_track, 1) + + if video_codec == 'vp8': + video_depacketizer = datachannel_module.define_rtc_packetizer_init() + video_depacketizer.ssrc = 0 + video_depacketizer.cname = b'video' + video_depacketizer.payloadType = payload_type + video_depacketizer.clockRate = 90000 + datachannel_library.rtcSetVP8Depacketizer(video_track, ctypes.byref(video_depacketizer)) + + datachannel_library.rtcChainRtcpReceivingSession(video_track) return video_track def create_audio_track_init(media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> RtcTrackInit: track_init = datachannel_module.define_rtc_track_init() + track_init.name = b'audio' + track_init.payloadType = payload_type if media_direction == 'sendonly': track_init.direction = 1 + track_init.mid = b'3' + track_init.ssrc = 43 + if media_direction == 'recvonly': track_init.direction = 2 + track_init.mid = b'2' + track_init.ssrc = 45 + + if media_direction == 'sendrecv': + track_init.direction = 3 + track_init.mid = b'1' + track_init.ssrc = 43 + if audio_codec == 'opus': track_init.codec = 128 - track_init.payloadType = payload_type - track_init.ssrc = 43 - track_init.name = b'audio' - track_init.mid = b'1' - return ctypes.byref(track_init) def create_video_track_init(media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> RtcTrackInit: track_init = datachannel_module.define_rtc_track_init() + track_init.name = b'video' + track_init.payloadType = payload_type if media_direction == 'sendonly': track_init.direction = 1 + track_init.mid = b'1' + track_init.ssrc = 42 + if media_direction == 'recvonly': track_init.direction = 2 + track_init.mid = b'0' + track_init.ssrc = 44 + + if media_direction == 'sendrecv': + track_init.direction = 3 + track_init.mid = b'0' + track_init.ssrc = 42 + if video_codec == 'av1': track_init.codec = 4 + if video_codec == 'vp8': track_init.codec = 1 - track_init.payloadType = payload_type - track_init.ssrc = 42 - track_init.name = b'video' - track_init.mid = b'0' - return ctypes.byref(track_init) -def detect_sdp_media(sdp_offer : SdpOffer) -> SdpMedia: - sdp_media : SdpMedia = {} +def get_payload_type(sdp_offer : SdpOffer, codec : AudioCodec | VideoCodec) -> int: + datachannel_library = datachannel_module.create_static_library() + payload_type_buffer = (ctypes.c_int * 16)() + payload_type_total = datachannel_library.rtcGetPayloadTypesForCodec(sdp_offer.encode(), codec.lower().encode(), payload_type_buffer, 16) - for line in sdp_offer.splitlines(): - if line.startswith('a=rtpmap:'): - if 'av1/90000' in line.lower(): - sdp_media['video'] =\ - { - 'codec': 'av1', - 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) - } - if 'vp8/90000' in line.lower(): - sdp_media['video'] =\ - { - 'codec': 'vp8', - 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) - } - if 'opus/48000/2' in line.lower(): - sdp_media['audio'] =\ - { - 'codec': 'opus', - 'payload_type': int(line.removeprefix('a=rtpmap:').split()[0]) - } + if payload_type_total: + return payload_type_buffer[0] - return sdp_media + return 0 diff --git a/facefusion/rtc_store.py b/facefusion/rtc_store.py index 2fb9bcbb..eeb1ae4b 100644 --- a/facefusion/rtc_store.py +++ b/facefusion/rtc_store.py @@ -3,7 +3,6 @@ from typing import List from facefusion import rtc from facefusion.types import RtcPeer, RtcStore, SessionId - RTC_STORE : RtcStore = {} diff --git a/facefusion/streamer.py b/facefusion/streamer.py index 026e1102..e8cf1b75 100644 --- a/facefusion/streamer.py +++ b/facefusion/streamer.py @@ -14,7 +14,7 @@ from facefusion.content_analyser import analyse_stream from facefusion.ffmpeg import open_ffmpeg from facefusion.filesystem import is_directory from facefusion.processors.core import get_processors_modules -from facefusion.types import Fps, StreamMode, VisionFrame +from facefusion.types import AudioFrame, Fps, StreamMode, VisionFrame from facefusion.vision import extract_vision_mask, read_static_images @@ -31,7 +31,8 @@ def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) - camera_capture.release() if numpy.any(capture_frame): - future = executor.submit(process_vision_frame, capture_frame) + audio_frame = create_empty_audio_frame() + future = executor.submit(process_frame, audio_frame, capture_frame) futures.append(future) for future_done in [ future for future in futures if future.done() ]: @@ -44,11 +45,10 @@ def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) - yield capture_deque.popleft() -def process_vision_frame(target_vision_frame : VisionFrame) -> VisionFrame: +def process_frame(stream_audio_frame : AudioFrame, stream_vision_frame : VisionFrame) -> VisionFrame: source_vision_frames = read_static_images(state_manager.get_item('source_paths')) - source_audio_frame = create_empty_audio_frame() source_voice_frame = create_empty_audio_frame() - temp_vision_frame = target_vision_frame.copy() + temp_vision_frame = stream_vision_frame.copy() temp_vision_mask = extract_vision_mask(temp_vision_frame) for processor_module in get_processors_modules(state_manager.get_item('processors')): @@ -58,9 +58,9 @@ def process_vision_frame(target_vision_frame : VisionFrame) -> VisionFrame: temp_vision_frame, temp_vision_mask = processor_module.process_frame( { 'source_vision_frames': source_vision_frames, - 'source_audio_frame': source_audio_frame, + 'source_audio_frame': stream_audio_frame, 'source_voice_frame': source_voice_frame, - 'target_vision_frame': target_vision_frame, + 'target_vision_frame': stream_vision_frame, 'temp_vision_frame': temp_vision_frame, 'temp_vision_mask': temp_vision_mask }) diff --git a/facefusion/types.py b/facefusion/types.py index f31906c9..a4af9310 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -94,8 +94,11 @@ AudioCodec : TypeAlias = Literal['opus'] VideoCodec : TypeAlias = Literal['av1', 'vp8'] AomEncoder : TypeAlias = ctypes.Array[ctypes.c_char] +AomDecoder : TypeAlias = ctypes.Array[ctypes.c_char] OpusEncoder : TypeAlias = ctypes.c_void_p +OpusDecoder : TypeAlias = ctypes.c_void_p VpxEncoder : TypeAlias = ctypes.Array[ctypes.c_char] +VpxDecoder : TypeAlias = ctypes.Array[ctypes.c_char] BitRate : TypeAlias = int SampleRate : TypeAlias = int @@ -274,18 +277,32 @@ StreamMode = Literal['udp', 'v4l2'] PeerConnection : TypeAlias = int SdpOffer : TypeAlias = str SdpAnswer : TypeAlias = str -MediaDirection : TypeAlias = Literal['sendonly', 'recvonly'] +MediaDirection : TypeAlias = Literal['sendonly', 'recvonly', 'sendrecv'] RtcTrackInit : TypeAlias = Any RtcVideoTrack : TypeAlias = int RtcAudioTrack : TypeAlias = int +RtcPeerAudio = TypedDict('RtcPeerAudio', +{ + 'sender_track': RtcAudioTrack, + 'receiver_track': RtcAudioTrack, + 'codec': AudioCodec, +}) + +RtcPeerVideo = TypedDict('RtcPeerVideo', +{ + 'sender_track': RtcVideoTrack, + 'receiver_track': RtcVideoTrack, + 'codec': VideoCodec, +}) + RtcPeer = TypedDict('RtcPeer', { 'peer_connection': PeerConnection, - 'video_track': RtcVideoTrack, - 'audio_track': RtcAudioTrack, + 'audio': NotRequired[RtcPeerAudio], + 'video': RtcPeerVideo, }) RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]] diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index bb98ba91..804cc783 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -1,19 +1,18 @@ import tempfile -import threading -from functools import partial from typing import Iterator from unittest.mock import patch import pytest from starlette.testclient import TestClient -from facefusion import metadata, rtc, session_manager, state_manager +from facefusion import metadata, rtc, rtc_store, session_manager, state_manager from facefusion.apis import asset_store from facefusion.apis.core import create_api from facefusion.core import common_pre_check from facefusion.download import conditional_download from facefusion.hash_helper import create_hash from facefusion.libraries import datachannel as datachannel_module +from facefusion.types import VideoCodec from .assert_helper import get_test_example_file, get_test_examples_directory @@ -37,6 +36,7 @@ def before_all() -> None: def before_each() -> None: session_manager.SESSIONS.clear() asset_store.clear() + rtc_store.clear() @pytest.fixture(scope = 'module') @@ -45,16 +45,6 @@ def test_client() -> Iterator[TestClient]: yield test_client -@pytest.fixture(scope = 'function') -def create_event() -> threading.Event: - return threading.Event() - - -@pytest.mark.helper -def set_event(session_id : str, media_buffer : bytes, timestamp : int, event : threading.Event) -> None: - event.set() - - def test_stream_image(test_client : TestClient) -> None: create_session_response = test_client.post('/session', json = { @@ -85,7 +75,7 @@ def test_stream_image(test_client : TestClient) -> None: assert select_response.status_code == 200 - with test_client.websocket_connect('/stream?mode=image', subprotocols = + with test_client.websocket_connect('/stream?type=image&action=process', subprotocols = [ 'access_token.' + access_token ]) as websocket: @@ -96,7 +86,7 @@ def test_stream_image(test_client : TestClient) -> None: @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) -def test_stream_video(test_client : TestClient, create_event : threading.Event, video_codec : str) -> None: +def test_stream_video(test_client : TestClient, video_codec : VideoCodec) -> None: create_session_response = test_client.post('/session', json = { 'client_version': metadata.get('version') @@ -124,27 +114,23 @@ def test_stream_video(test_client : TestClient, create_event : threading.Event, 'Authorization': 'Bearer ' + access_token }) - with patch('facefusion.rtc.send_video_to_peers', side_effect = partial(set_event, event = create_event)): - with test_client.websocket_connect('/stream?mode=video&codec=' + video_codec, subprotocols = - [ - 'access_token.' + access_token - ]) as websocket: - websocket.send_bytes(chr(1).encode() + source_content) - websocket.receive_text() + peer_connection = rtc.create_peer_connection() - peer_connection = rtc.create_peer_connection() - rtc.add_video_track(peer_connection, 'recvonly', 'vp8', 96) - rtc.add_audio_track(peer_connection, 'recvonly', 'opus', 111) - sdp_offer = rtc.create_sdp_offer(peer_connection) - datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) - stream_response = test_client.post('/stream', content = sdp_offer, headers = - { - 'Authorization': 'Bearer ' + access_token, - 'Content-Type': 'application/sdp' - }) + if video_codec == 'av1': + rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 35) + if video_codec == 'vp8': + rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 96) - assert stream_response.status_code == 201 + rtc.add_audio_track(peer_connection, 'sendrecv', 'opus', 111) + sdp_offer = rtc.create_sdp_offer(peer_connection) + datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) - create_event.wait(timeout = 10) + with patch('facefusion.rtc.send_video'): + stream_response = test_client.post('/stream?type=video&action=process', content = sdp_offer, headers = + { + 'Authorization': 'Bearer ' + access_token, + 'Content-Type': 'application/sdp' + }) - assert create_event.is_set() + assert stream_response.status_code == 201 + assert 'm=video' in stream_response.text diff --git a/tests/test_codec_aom_decoder.py b/tests/test_codec_aom_decoder.py new file mode 100644 index 00000000..00838293 --- /dev/null +++ b/tests/test_codec_aom_decoder.py @@ -0,0 +1,61 @@ +from unittest.mock import patch + +import cv2 +import pytest +from tests.assert_helper import get_test_example_file, get_test_examples_directory + +from facefusion import state_manager +from facefusion.codecs.aom_decoder import create, decode, destroy, read_resolution +from facefusion.codecs.aom_encoder import create as create_encoder, encode +from facefusion.download import conditional_download +from facefusion.libraries import aom as aom_module +from facefusion.vision import read_video_frame + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) + + conditional_download(get_test_examples_directory(), [ 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' ]) + + aom_module.pre_check() + + +#TODO: needs review +def test_create() -> None: + assert create() + + +#TODO: needs review +def test_decode() -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) + aom_encoder = create_encoder(video_resolution, 1000, 1, 0) + encoded_buffer = encode(aom_encoder, video_buffer, video_resolution, 0) + decode_resolution = read_resolution(create(), encoded_buffer) + + assert len(decode(create(), encoded_buffer)) == decode_resolution[0] * decode_resolution[1] * 3 // 2 + assert decode(create(), bytes()) == bytes() + + +#TODO: needs review +def test_read_resolution() -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) + aom_encoder = create_encoder(video_resolution, 1000, 1, 0) + encoded_buffer = encode(aom_encoder, video_buffer, video_resolution, 0) + + assert read_resolution(create(), encoded_buffer)[0] >= video_resolution[0] + assert read_resolution(create(), encoded_buffer)[1] >= video_resolution[1] + assert read_resolution(create(), bytes()) is None + + +#TODO: needs review +def test_destroy() -> None: + aom_decoder = create() + + with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock: + destroy(aom_decoder) + mock.assert_called_once_with(aom_decoder) diff --git a/tests/test_codec_aom.py b/tests/test_codec_aom_encoder.py similarity index 67% rename from tests/test_codec_aom.py rename to tests/test_codec_aom_encoder.py index 918f1a50..f0d2b1a3 100644 --- a/tests/test_codec_aom.py +++ b/tests/test_codec_aom_encoder.py @@ -5,7 +5,7 @@ import pytest from tests.assert_helper import get_test_example_file, get_test_examples_directory from facefusion import state_manager -from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer +from facefusion.codecs.aom_encoder import create, destroy, encode from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download from facefusion.hash_helper import create_hash @@ -22,27 +22,27 @@ def before_all() -> None: aom_module.pre_check() -def test_create_aom_encoder() -> None: - assert create_aom_encoder(1000, 8, 16, (320, 240)) - assert create_aom_encoder(0, 0, 0, (0, 0)) is None +def test_create() -> None: + assert create((320, 240), 1000, 8, 16) + assert create((0, 0), 0, 0, 0) is None -def test_encode_aom_buffer() -> None: +def test_encode() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) - aom_encoder = create_aom_encoder(1000, 1, 0, video_resolution) + aom_encoder = create(video_resolution, 1000, 1, 0) if is_linux() or is_windows(): - assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31' + assert create_hash(encode(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31' if is_macos(): pytest.skip() -def test_destroy_aom_encoder() -> None: - aom_encoder = create_aom_encoder(1000, 8, 16, (320, 240)) +def test_destroy() -> None: + aom_encoder = create((320, 240), 1000, 8, 16) with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock: - destroy_aom_encoder(aom_encoder) + destroy(aom_encoder) mock.assert_called_once_with(aom_encoder) diff --git a/tests/test_codec_opus_decoder.py b/tests/test_codec_opus_decoder.py new file mode 100644 index 00000000..191e238e --- /dev/null +++ b/tests/test_codec_opus_decoder.py @@ -0,0 +1,47 @@ +from unittest.mock import patch + +import numpy +import pytest +from tests.assert_helper import get_test_example_file, get_test_examples_directory + +from facefusion import state_manager +from facefusion.codecs.opus_decoder import create, decode, destroy +from facefusion.codecs.opus_encoder import create as create_encoder, encode +from facefusion.download import conditional_download +from facefusion.ffmpeg import read_audio_buffer +from facefusion.libraries import opus as opus_module + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) + + conditional_download(get_test_examples_directory(), [ 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3' ]) + + opus_module.pre_check() + + +#TODO: needs review +def test_create() -> None: + assert create(48000, 2) + assert create(0, 0) is None + + +#TODO: needs review +def test_decode() -> None: + audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) + audio_sample = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 + opus_encoder = create_encoder(48000, 2) + encoded_buffer = encode(opus_encoder, audio_sample.tobytes(), 960) + opus_decoder = create(48000, 2) + + assert len(decode(opus_decoder, encoded_buffer, 960, 2)) == 960 * 2 * 4 + + +#TODO: needs review +def test_destroy() -> None: + opus_decoder = create(48000, 2) + + with patch.object(opus_module.create_static_library(), 'opus_decoder_destroy') as mock: + destroy(opus_decoder) + mock.assert_called_once_with(opus_decoder) diff --git a/tests/test_codec_opus.py b/tests/test_codec_opus_encoder.py similarity index 69% rename from tests/test_codec_opus.py rename to tests/test_codec_opus_encoder.py index d2915735..48622bb9 100644 --- a/tests/test_codec_opus.py +++ b/tests/test_codec_opus_encoder.py @@ -5,7 +5,7 @@ import pytest from tests.assert_helper import get_test_example_file, get_test_examples_directory from facefusion import state_manager -from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer +from facefusion.codecs.opus_encoder import create, destroy, encode from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download from facefusion.ffmpeg import read_audio_buffer @@ -22,26 +22,26 @@ def before_all() -> None: opus_module.pre_check() -def test_create_opus_encoder() -> None: - assert create_opus_encoder(48000, 2) - assert create_opus_encoder(0, 0) is None +def test_create() -> None: + assert create(48000, 2) + assert create(0, 0) is None -def test_encode_opus_buffer() -> None: +def test_encode() -> None: audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) audio_sample = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 - opus_encoder = create_opus_encoder(48000, 2) + opus_encoder = create(48000, 2) if is_linux() or is_windows(): - assert create_hash(encode_opus_buffer(opus_encoder, audio_sample.tobytes(), 960)) == '8abe71cf' + assert create_hash(encode(opus_encoder, audio_sample.tobytes(), 960)) == '8abe71cf' if is_macos(): pytest.skip() -def test_destroy_opus_encoder() -> None: - opus_encoder = create_opus_encoder(48000, 2) +def test_destroy() -> None: + opus_encoder = create(48000, 2) with patch.object(opus_module.create_static_library(), 'opus_encoder_destroy') as mock: - destroy_opus_encoder(opus_encoder) + destroy(opus_encoder) mock.assert_called_once_with(opus_encoder) diff --git a/tests/test_codec_vpx_decoder.py b/tests/test_codec_vpx_decoder.py new file mode 100644 index 00000000..64ef0ed0 --- /dev/null +++ b/tests/test_codec_vpx_decoder.py @@ -0,0 +1,61 @@ +from unittest.mock import patch + +import cv2 +import pytest +from tests.assert_helper import get_test_example_file, get_test_examples_directory + +from facefusion import state_manager +from facefusion.codecs.vpx_decoder import create, decode, destroy, read_resolution +from facefusion.codecs.vpx_encoder import create as create_encoder, encode +from facefusion.download import conditional_download +from facefusion.libraries import vpx as vpx_module +from facefusion.vision import read_video_frame + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) + + conditional_download(get_test_examples_directory(), [ 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' ]) + + vpx_module.pre_check() + + +#TODO: needs review +def test_create() -> None: + assert create() + + +#TODO: needs review +def test_decode() -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) + vpx_encoder = create_encoder(video_resolution, 1000, 1, 0) + encoded_buffer = encode(vpx_encoder, video_buffer, video_resolution, 0) + vpx_decoder = create() + + assert len(decode(vpx_decoder, encoded_buffer)) == video_resolution[0] * video_resolution[1] * 3 // 2 + assert decode(vpx_decoder, bytes()) == bytes() + + +#TODO: needs review +def test_read_resolution() -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) + vpx_encoder = create_encoder(video_resolution, 1000, 1, 0) + encoded_buffer = encode(vpx_encoder, video_buffer, video_resolution, 0) + vpx_decoder = create() + + assert read_resolution(vpx_decoder, encoded_buffer) == video_resolution + assert read_resolution(vpx_decoder, bytes()) is None + + +#TODO: needs review +def test_destroy() -> None: + vpx_decoder = create() + + with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock: + destroy(vpx_decoder) + mock.assert_called_once_with(vpx_decoder) diff --git a/tests/test_codec_vpx.py b/tests/test_codec_vpx_encoder.py similarity index 67% rename from tests/test_codec_vpx.py rename to tests/test_codec_vpx_encoder.py index 93b25166..ec9aacca 100644 --- a/tests/test_codec_vpx.py +++ b/tests/test_codec_vpx_encoder.py @@ -5,7 +5,7 @@ import pytest from tests.assert_helper import get_test_example_file, get_test_examples_directory from facefusion import state_manager -from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer +from facefusion.codecs.vpx_encoder import create, destroy, encode from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download from facefusion.hash_helper import create_hash @@ -22,27 +22,27 @@ def before_all() -> None: vpx_module.pre_check() -def test_create_vpx_encoder() -> None: - assert create_vpx_encoder(1000, 8, 16, (320, 240)) - assert create_vpx_encoder(0, 0, 0, (0, 0)) is None +def test_create() -> None: + assert create((320, 240), 1000, 8, 16) + assert create((0, 0), 0, 0, 0) is None -def test_encode_vpx_buffer() -> None: +def test_encode() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) - vpx_encoder = create_vpx_encoder(1000, 1, 0, video_resolution) + vpx_encoder = create(video_resolution, 1000, 1, 0) if is_linux() or is_windows(): - assert create_hash(encode_vpx_buffer(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f' + assert create_hash(encode(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f' if is_macos(): pytest.skip() -def test_destroy_vpx_encoder() -> None: - vpx_encoder = create_vpx_encoder(1000, 8, 16, (320, 240)) +def test_destroy() -> None: + vpx_encoder = create((320, 240), 1000, 8, 16) with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock: - destroy_vpx_encoder(vpx_encoder) + destroy(vpx_encoder) mock.assert_called_once_with(vpx_encoder) diff --git a/tests/test_rtc.py b/tests/test_rtc.py index 544aec94..a09f4b45 100644 --- a/tests/test_rtc.py +++ b/tests/test_rtc.py @@ -4,7 +4,7 @@ import pytest from facefusion import state_manager from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module -from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, detect_sdp_media, send_audio_to_peers, send_video_to_peers, set_remote_description +from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, get_payload_type, send_audio, send_video, set_remote_description from facefusion.types import RtcPeer @@ -57,48 +57,56 @@ def test_create_sdp_answer() -> None: assert 'm=video' in sdp_answer assert 'VP8/90000' in sdp_answer - assert 'a=ssrc:42 cname:video' in sdp_answer assert 'm=audio' in sdp_answer assert 'opus/48000/2' in sdp_answer - assert 'a=ssrc:43 cname:audio' in sdp_answer assert 'a=recvonly' in sdp_answer assert datachannel_library.rtcDeletePeerConnection(sender_peer_connection) == 0 assert datachannel_library.rtcDeletePeerConnection(receiver_peer_connection) == 0 -def test_send_audio_to_peers() -> None: +def test_send_video() -> None: datachannel_library = datachannel_module.create_static_library() peer_connection = create_peer_connection() - audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111) - rtc_peers : List[RtcPeer] =\ - [ + video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96) + rtc_peer : RtcPeer =\ + { + 'peer_connection': peer_connection, + 'video': { - 'peer_connection': peer_connection, - 'video_track': 0, - 'audio_track': audio_track + 'sender_track': video_track, + 'receiver_track': video_track, + 'codec': 'vp8' } - ] + } - send_audio_to_peers(rtc_peers, bytes(960), 0) + send_video(rtc_peer, bytes(1024), 0) datachannel_library.rtcDeletePeerConnection(peer_connection) -def test_send_video_to_peers() -> None: +def test_send_audio() -> None: datachannel_library = datachannel_module.create_static_library() peer_connection = create_peer_connection() - video_track = add_video_track(peer_connection, 'sendonly', 'vp8', 96) - rtc_peers : List[RtcPeer] =\ - [ + audio_track = add_audio_track(peer_connection, 'sendonly', 'opus', 111) + rtc_peer : RtcPeer =\ + { + 'peer_connection': peer_connection, + 'video': { - 'peer_connection': peer_connection, - 'video_track': video_track, - 'audio_track': 0 + 'sender_track': 0, + 'receiver_track': 0, + 'codec': 'vp8' + }, + 'audio': + { + 'sender_track': audio_track, + 'receiver_track': audio_track, + 'codec': 'opus' } - ] + } - send_video_to_peers(rtc_peers, bytes(1024), 0) + send_audio(rtc_peer, bytes(960), 0) datachannel_library.rtcDeletePeerConnection(peer_connection) @@ -110,8 +118,12 @@ def test_delete_peers() -> None: [ { 'peer_connection': peer_connection, - 'video_track': 0, - 'audio_track': 0 + 'video': + { + 'sender_track': 0, + 'receiver_track': 0, + 'codec': 'vp8' + } } ] @@ -120,16 +132,14 @@ def test_delete_peers() -> None: assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1 -def test_detect_sdp_media() -> None: +def test_get_payload_type() -> None: peer_connection = create_peer_connection() add_video_track(peer_connection, 'sendonly', 'vp8', 96) add_audio_track(peer_connection, 'sendonly', 'opus', 111) sdp_offer = create_sdp_offer(peer_connection) - sdp_payload = detect_sdp_media(sdp_offer) - assert sdp_payload.get('video').get('codec') == 'vp8' - assert sdp_payload.get('video').get('payload_type') == 96 - assert sdp_payload.get('audio').get('codec') == 'opus' - assert sdp_payload.get('audio').get('payload_type') == 111 + assert get_payload_type(sdp_offer, 'vp8') == 96 + assert get_payload_type(sdp_offer, 'opus') == 111 + assert get_payload_type(sdp_offer, 'av1') == 0 datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py deleted file mode 100644 index c2244db9..00000000 --- a/tests/test_stream_helper.py +++ /dev/null @@ -1,238 +0,0 @@ -import asyncio -import queue -from typing import Any, Optional -from unittest.mock import AsyncMock, MagicMock, patch - -import cv2 -import numpy -import pytest -from numpy.typing import NDArray -from starlette.websockets import WebSocketState - -from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream -from facefusion.hash_helper import create_hash -from facefusion.types import VideoCodec, VisionFrame - - -def _make_handler_websocket(events : list[Any]) -> MagicMock: - mock = MagicMock() - mock.scope = {} - mock.client_state = WebSocketState.CONNECTED - mock.accept = AsyncMock() - mock.send_text = AsyncMock() - mock.close = AsyncMock() - mock.receive = AsyncMock(side_effect = events) - return mock - - -def _make_video_packet(frame : NDArray[Any]) -> bytes: - _, encoded = cv2.imencode('.jpg', frame) - return b'\x01' + encoded.tobytes() - - -def _make_audio_packet(samples : NDArray[Any]) -> bytes: - return b'\x02' + samples.tobytes() - - -@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) -def test_encode_video_loop(video_codec : VideoCodec) -> None: - frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) - small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) - large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8) - black_frame = numpy.zeros((64, 64, 3), dtype = numpy.uint8) - prefix = 'facefusion.apis.stream_helper.' - - create_name = prefix + 'create_aom_encoder' - encode_name = prefix + 'encode_aom_buffer' - destroy_name = prefix + 'destroy_aom_encoder' - - if video_codec == 'vp8': - create_name = prefix + 'create_vpx_encoder' - encode_name = prefix + 'encode_vpx_buffer' - destroy_name = prefix + 'destroy_vpx_encoder' - - vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue() - vision_frame_queue.put(frame) - vision_frame_queue.put(None) - with patch(prefix + 'process_vision_frame', return_value = frame), \ - patch(create_name, return_value = MagicMock()), \ - patch(encode_name, return_value = b'encoded'), \ - patch(destroy_name), \ - patch(prefix + 'rtc_store') as mock_rtc_store, \ - patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_peers.return_value = [ MagicMock() ] - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - mock_rtc.send_video_to_peers.assert_called_once() - - vision_frame_queue = queue.Queue() - vision_frame_queue.put(frame) - vision_frame_queue.put(frame) - vision_frame_queue.put(frame) - vision_frame_queue.put(None) - with patch(prefix + 'process_vision_frame', return_value = frame), \ - patch(create_name, return_value = MagicMock()), \ - patch(encode_name, return_value = b'encoded'), \ - patch(destroy_name), \ - patch(prefix + 'rtc_store') as mock_rtc_store, \ - patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_peers.return_value = [ MagicMock() ] - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - assert mock_rtc.send_video_to_peers.call_count == 3 - - vision_frame_queue = queue.Queue() - vision_frame_queue.put(black_frame) - with patch(create_name, return_value = MagicMock()), \ - patch(destroy_name) as mock_destroy, \ - patch(prefix + 'rtc') as mock_rtc: - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - mock_rtc.send_video_to_peers.assert_not_called() - mock_destroy.assert_called_once() - - vision_frame_queue = queue.Queue() - vision_frame_queue.put(frame) - vision_frame_queue.put(None) - with patch(prefix + 'process_vision_frame', return_value = frame), \ - patch(create_name, return_value = MagicMock()), \ - patch(encode_name, return_value = b''), \ - patch(destroy_name), \ - patch(prefix + 'rtc_store'), \ - patch(prefix + 'rtc') as mock_rtc: - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - mock_rtc.send_video_to_peers.assert_not_called() - - vision_frame_queue = queue.Queue() - vision_frame_queue.put(small_frame) - vision_frame_queue.put(None) - with patch(prefix + 'process_vision_frame', return_value = large_frame), \ - patch(create_name, return_value = MagicMock()) as mock_create, \ - patch(encode_name, return_value = b'encoded'), \ - patch(destroy_name) as mock_destroy, \ - patch(prefix + 'rtc_store') as mock_rtc_store, \ - patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_peers.return_value = [ MagicMock() ] - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - assert mock_create.call_count == 2 - assert mock_destroy.call_count == 2 - mock_rtc.send_video_to_peers.assert_called_once() - - vision_frame_queue = queue.Queue() - vision_frame_queue.put(frame) - vision_frame_queue.put(None) - with patch(create_name, return_value = None), \ - patch(prefix + 'rtc') as mock_rtc: - encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) - mock_rtc.send_video_to_peers.assert_not_called() - - -# TODO: refine test -def test_encode_audio_loop() -> None: - audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes() - - audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue() - audio_chunk_queue.put(audio_chunk) - audio_chunk_queue.put(None) - with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ - patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \ - patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ - patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ - patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - mock_rtc_store.get_peers.return_value = [ MagicMock() ] - encode_audio_loop('opus', audio_chunk_queue, 'session-1') - mock_rtc.send_audio_to_peers.assert_called_once() - assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0 - - audio_chunk_queue = queue.Queue() - audio_chunk_queue.put(audio_chunk) - audio_chunk_queue.put(audio_chunk) - audio_chunk_queue.put(None) - with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ - patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b'encoded'), \ - patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ - patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ - patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - mock_rtc_store.get_peers.return_value = [ MagicMock() ] - encode_audio_loop('opus', audio_chunk_queue, 'session-1') - assert mock_rtc.send_audio_to_peers.call_count == 2 - assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0 - assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960 - - audio_chunk_queue = queue.Queue() - audio_chunk_queue.put(audio_chunk) - audio_chunk_queue.put(None) - with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ - patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = b''), \ - patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ - patch('facefusion.apis.stream_helper.rtc_store'), \ - patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - encode_audio_loop('opus', audio_chunk_queue, 'session-1') - mock_rtc.send_audio_to_peers.assert_not_called() - - audio_chunk_queue = queue.Queue() - audio_chunk_queue.put(audio_chunk) - audio_chunk_queue.put(None) - with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ - patch('facefusion.apis.stream_helper.encode_opus_buffer', return_value = None), \ - patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ - patch('facefusion.apis.stream_helper.rtc_store'), \ - patch('facefusion.apis.stream_helper.rtc'): - encode_audio_loop('opus', audio_chunk_queue, 'session-1') - mock_destroy.assert_called_once() - - audio_chunk_queue = queue.Queue() - audio_chunk_queue.put(b'') - with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ - patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ - patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - encode_audio_loop('opus', audio_chunk_queue, 'session-1') - mock_rtc.send_audio_to_peers.assert_not_called() - mock_destroy.assert_called_once() - - -# TODO: refine test -def test_handle_video_stream() -> None: - frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) - video_packet = _make_video_packet(frame) - audio_packet = _make_audio_packet(numpy.zeros(1920, dtype = numpy.float32)) - - websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ]) - with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ - patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ - patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ - patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ - patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ - patch('facefusion.apis.stream_helper.encode_video_loop') as mock_loop, \ - patch('facefusion.apis.stream_helper.encode_audio_loop'), \ - patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc: - asyncio.run(handle_video_stream(websocket)) - websocket.accept.assert_called_once_with(subprotocol = 'proto') - websocket.send_text.assert_called_once_with('ready') - websocket.close.assert_called_once() - mock_rtc.init_peers.assert_called_once_with('session-1') - mock_rtc.delete_peers.assert_called_once_with('session-1') - _, _, loop_session_id, loop_resolution = mock_loop.call_args[0] - assert loop_session_id == 'session-1' - assert loop_resolution == (64, 64) - - websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.disconnect'} ]) - with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ - patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ - patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = None), \ - patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ - patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc: - asyncio.run(handle_video_stream(websocket)) - websocket.accept.assert_called_once() - websocket.send_text.assert_not_called() - mock_rtc.init_peers.assert_not_called() - - websocket = _make_handler_websocket([ {'type': 'websocket.receive', 'bytes': video_packet}, {'type': 'websocket.receive', 'bytes': audio_packet}, {'type': 'websocket.disconnect'} ]) - with patch('facefusion.apis.stream_helper.get_sec_websocket_protocol', return_value = 'proto'), \ - patch('facefusion.apis.stream_helper.extract_access_token', return_value = 'token'), \ - patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ - patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ - patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ - patch('facefusion.apis.stream_helper.encode_video_loop'), \ - patch('facefusion.apis.stream_helper.encode_audio_loop') as mock_audio_loop, \ - patch('facefusion.apis.stream_helper.rtc_store'): - asyncio.run(handle_video_stream(websocket)) - audio_queue = mock_audio_loop.call_args[0][1] - assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'