From 775985645e3c8b5e2bdf7c5668af21f6b6ddcef8 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Fri, 5 Jun 2026 00:18:57 +0200 Subject: [PATCH] Push based receive with queue (#1146) * move to push based receive * move to push based receive, fix mocks * fix tests * add todos * remove asyncio * remove asyncio * resolve todos * move to queue without events * prevent debug spam * concurrent stream inference stream_video.py: pipeline face-swap inference across execution_thread_count workers (ThreadPoolExecutor + bounded in-flight deque, ordered encode) to keep the GPU busy during encode Co-Authored-By: Claude Opus 4.8 * add todos * add todos * add missing state --------- Co-authored-by: Claude Opus 4.8 --- facefusion/apis/stream_audio.py | 62 ++++++------------ facefusion/apis/stream_event.py | 21 +++++-- facefusion/apis/stream_manager.py | 38 ++++++----- facefusion/apis/stream_video.py | 97 ++++++++++++++--------------- facefusion/libraries/datachannel.py | 10 +-- facefusion/types.py | 2 + tests/test_api_stream_audio.py | 64 +++++++++---------- tests/test_api_stream_manager.py | 3 +- tests/test_api_stream_video.py | 96 +++++++++++++--------------- tests/test_codec_aom_decoder.py | 4 +- tests/test_codec_aom_encoder.py | 2 +- tests/test_codec_opus_decoder.py | 2 +- tests/test_codec_opus_encoder.py | 2 +- tests/test_codec_vpx_decoder.py | 2 +- tests/test_codec_vpx_encoder.py | 2 +- 15 files changed, 193 insertions(+), 214 deletions(-) diff --git a/facefusion/apis/stream_audio.py b/facefusion/apis/stream_audio.py index c90ae156..b4f4f010 100644 --- a/facefusion/apis/stream_audio.py +++ b/facefusion/apis/stream_audio.py @@ -1,22 +1,19 @@ import ctypes -import threading import time -from collections import deque +from functools import partial +from queue import Queue from typing import Optional import numpy from facefusion import rtc -from facefusion.apis.stream_event import create_event +from facefusion.apis.stream_event import create_receive_event from facefusion.codecs import opus_decoder, opus_encoder -from facefusion.libraries import datachannel as datachannel_module from facefusion.types import AudioCodec, AudioPack, OpusDecoder, RtcPeer, RtcPeerAudio -def run_audio_encode_loop(rtc_peer : RtcPeer, audio_deque : deque[AudioPack], audio_event : threading.Event) -> None: - audio_event.wait() - audio_event.clear() - temp_audio_frame, temp_audio_time = audio_deque.popleft() +def run_audio_encode_loop(rtc_peer : RtcPeer, audio_queue : Queue[AudioPack]) -> None: + temp_audio_frame, temp_audio_time = audio_queue.get() audio_encoder = opus_encoder.create(48000, 2) while numpy.any(temp_audio_frame): @@ -26,64 +23,45 @@ def run_audio_encode_loop(rtc_peer : RtcPeer, audio_deque : deque[AudioPack], au audio_timestamp = int(temp_audio_time * 48000) rtc.send_audio(rtc_peer, output_audio_buffer, audio_timestamp) - if len(audio_deque) == 0: - audio_event.wait() - audio_event.clear() - - temp_audio_frame, temp_audio_time = audio_deque.popleft() + temp_audio_frame, temp_audio_time = audio_queue.get() opus_encoder.destroy(audio_encoder) -def receive_audio_frames(rtc_peer_audio : RtcPeerAudio, audio_deque : deque[AudioPack], audio_event : threading.Event) -> None: +def receive_audio_frames(rtc_peer_audio : RtcPeerAudio, audio_queue : Queue[AudioPack]) -> None: audio_track = rtc_peer_audio.get('receiver_track') audio_codec = rtc_peer_audio.get('codec') - datachannel_library = datachannel_module.create_static_library() audio_decoder = create_audio_decoder(audio_codec) - receive_buffer = ctypes.create_string_buffer(8 * 1024) - available_event = create_event(audio_track, datachannel_library) - receive_status_code = -3 - - while receive_status_code == 0 or receive_status_code == -3: - buffer_size = ctypes.c_int(8 * 1024) - receive_status_code = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size)) - - if receive_status_code == 0 and buffer_size.value > 0: - audio_buffer = receive_buffer.raw[:buffer_size.value] - fill_audio_deque(audio_codec, audio_decoder, audio_buffer, audio_deque, audio_event) - - if receive_status_code == -3: - available_event.wait() - available_event.clear() + audio_frame_handler = partial(handle_audio_frame, audio_codec, audio_decoder, audio_queue) + receive_event = create_receive_event(audio_track, audio_frame_handler) + receive_event.wait() empty_audio_frame = numpy.empty(0) - audio_deque.append((empty_audio_frame, 0.0)) - audio_event.set() + audio_queue.put((empty_audio_frame, 0.0)) destroy_audio_decoder(audio_codec, audio_decoder) -def fill_audio_deque(audio_codec : AudioCodec, audio_decoder : OpusDecoder, audio_buffer : bytes, audio_deque : deque[AudioPack], audio_event : threading.Event) -> None: - audio_frame = decode_audio_frame(audio_codec, audio_decoder, audio_buffer) - - if audio_frame: - audio_deque.append((numpy.frombuffer(audio_frame, dtype = numpy.float32), time.monotonic())) - audio_event.set() - - def decode_audio_frame(audio_codec : AudioCodec, audio_decoder : OpusDecoder, input_buffer : bytes) -> Optional[bytes]: if audio_codec == 'opus': return opus_decoder.decode(audio_decoder, input_buffer, 960, 2) - return None def create_audio_decoder(audio_codec : AudioCodec) -> Optional[OpusDecoder]: if audio_codec == 'opus': return opus_decoder.create(48000, 2) - return None def destroy_audio_decoder(audio_codec : AudioCodec, audio_decoder : OpusDecoder) -> None: if audio_codec == 'opus': opus_decoder.destroy(audio_decoder) + + +def handle_audio_frame(audio_codec : AudioCodec, audio_decoder : OpusDecoder, audio_queue : Queue[AudioPack], track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: + audio_buffer = ctypes.string_at(data, size) + audio_frame = decode_audio_frame(audio_codec, audio_decoder, audio_buffer) + + if audio_frame: + temp_audio_frame = numpy.frombuffer(audio_frame, dtype = numpy.float32) + audio_queue.put((temp_audio_frame, time.monotonic())) diff --git a/facefusion/apis/stream_event.py b/facefusion/apis/stream_event.py index 17ec2c64..aafb1ac6 100644 --- a/facefusion/apis/stream_event.py +++ b/facefusion/apis/stream_event.py @@ -2,13 +2,22 @@ import ctypes import threading from functools import partial +from facefusion.libraries import datachannel as datachannel_module +from facefusion.types import FrameHandler -def create_event(track : int, datachannel_library : ctypes.CDLL) -> threading.Event: - available_event = threading.Event() - available_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)(partial(dispatch_event, available_event)) - datachannel_library.rtcSetAvailableCallback(track, available_callback) - available_event.callback = available_callback # type: ignore[attr-defined] - return available_event + +def create_receive_event(track : int, frame_handler : FrameHandler) -> threading.Event: + datachannel_library = datachannel_module.create_static_library() + receive_event = threading.Event() + + frame_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)(frame_handler) + close_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)(partial(dispatch_event, receive_event)) + datachannel_library.rtcSetFrameCallback(track, frame_callback) + datachannel_library.rtcSetClosedCallback(track, close_callback) + receive_event.frame_callback = frame_callback # type: ignore[attr-defined] + receive_event.close_callback = close_callback # type: ignore[attr-defined] + + return receive_event def dispatch_event(event : threading.Event, track : int, pointer : ctypes.c_void_p) -> None: diff --git a/facefusion/apis/stream_manager.py b/facefusion/apis/stream_manager.py index d7861273..460f5305 100644 --- a/facefusion/apis/stream_manager.py +++ b/facefusion/apis/stream_manager.py @@ -1,8 +1,7 @@ -import asyncio import ctypes import threading -from collections import deque from collections.abc import AsyncIterator +from queue import Queue from typing import Optional import cv2 @@ -66,9 +65,9 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA audio_sender_track = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, audio_payload_type) rtc.set_remote_description(peer_connection, sdp_offer) - local_sdp = rtc.create_sdp_answer(peer_connection) + sdp_answer = rtc.create_sdp_answer(peer_connection) - if local_sdp: + if sdp_answer: rtc_peer : RtcPeer =\ { 'peer_connection': peer_connection, @@ -92,29 +91,34 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA rtc_store.init_peers(session_id) rtc_store.get_peers(session_id).append(rtc_peer) - threading.Thread(target = asyncio.run, args = (run_peer_loop(session_id, rtc_peer),), daemon = True).start() - return local_sdp + threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True).start() + + return sdp_answer datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) return None -async def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: - video_deque : deque[VideoPack] = deque(maxlen = 1) - audio_deque : deque[AudioPack] = deque(maxlen = 10) - video_event = threading.Event() +def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: + video_queue : Queue[VideoPack] = Queue(maxsize = 30) + audio_queue : Queue[AudioPack] = Queue(maxsize = 300) - video_receiver_thread = asyncio.to_thread(receive_video_frames, rtc_peer.get('video'), video_deque, video_event) - video_encoder_thread = asyncio.to_thread(run_video_encode_loop, rtc_peer, video_deque, video_event) - coroutines = [ video_receiver_thread, video_encoder_thread ] + video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer.get('video'), video_queue), daemon = True) + video_encoder_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_queue), daemon = True) + video_receiver_thread.start() + video_encoder_thread.start() if rtc_peer.get('audio'): - audio_event = threading.Event() - coroutines.append(asyncio.to_thread(receive_audio_frames, rtc_peer.get('audio'), audio_deque, audio_event)) - coroutines.append(asyncio.to_thread(run_audio_encode_loop, rtc_peer, audio_deque, audio_event)) + audio_receiver_thread = threading.Thread(target = receive_audio_frames, args = (rtc_peer.get('audio'), audio_queue), daemon = True) + audio_encoder_thread = threading.Thread(target = run_audio_encode_loop, args = (rtc_peer, audio_queue), daemon = True) + audio_receiver_thread.start() + audio_encoder_thread.start() + audio_receiver_thread.join() + audio_encoder_thread.join() - await asyncio.gather(*coroutines) + video_receiver_thread.join() + video_encoder_thread.join() rtc_store.delete_peers(session_id) diff --git a/facefusion/apis/stream_video.py b/facefusion/apis/stream_video.py index d70ed421..dedc5cae 100644 --- a/facefusion/apis/stream_video.py +++ b/facefusion/apis/stream_video.py @@ -1,93 +1,77 @@ import ctypes -import threading import time from collections import deque +from concurrent.futures import Future, ThreadPoolExecutor +from functools import partial +from queue import Queue from typing import Optional import cv2 import numpy -from facefusion import rtc, streamer -from facefusion.apis.stream_event import create_event +from facefusion import rtc, state_manager, streamer +from facefusion.apis.stream_event import create_receive_event from facefusion.audio import create_empty_audio_frame from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder -from facefusion.libraries import datachannel as datachannel_module from facefusion.types import AomDecoder, AomEncoder, AomPointer, BitRate, Resolution, RtcPeer, RtcPeerVideo, VideoCodec, VideoPack, VisionFrame, VpxDecoder, VpxEncoder, VpxPointer -def run_video_encode_loop(rtc_peer : RtcPeer, video_deque : deque[VideoPack], video_event : threading.Event) -> None: - video_event.wait() - video_event.clear() +def run_video_encode_loop(rtc_peer : RtcPeer, video_queue : Queue[VideoPack]) -> None: video_codec = rtc_peer.get('video').get('codec') - temp_vision_frame, temp_video_time = video_deque.popleft() + temp_vision_frame, temp_video_time = video_queue.get() if numpy.any(temp_vision_frame): temp_resolution : Resolution = (temp_vision_frame.shape[1], temp_vision_frame.shape[0]) temp_bitrate : BitRate = 8000 video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate) previous_video_time = temp_video_time + #todo: find less complex type here + temp_deque : deque[tuple[Future[tuple[bytes, Resolution]], float]] = deque() + execution_thread_count = state_manager.get_item('execution_thread_count') frame_index = 0 - while numpy.any(temp_vision_frame): - encode_start = time.monotonic() - output_vision_buffer, output_resolution = process_video_frame(temp_vision_frame) - peer_bitrate : BitRate = rtc_peer.get('sender_bitrate').value - video_encoder, temp_resolution, temp_bitrate, frame_index = adapt_video_encoder(video_codec, video_encoder, temp_resolution, temp_bitrate, output_resolution, peer_bitrate, frame_index) - output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index) + with ThreadPoolExecutor(max_workers = execution_thread_count) as executor: + while numpy.any(temp_vision_frame) or temp_deque: - if output_video_buffer: - rtc.send_video(rtc_peer, output_video_buffer, int(temp_video_time * 90000)) + if numpy.any(temp_vision_frame) and len(temp_deque) < execution_thread_count: + temp_deque.append((executor.submit(process_video_frame, temp_vision_frame), temp_video_time)) + temp_vision_frame, temp_video_time = video_queue.get() - encode_time = time.monotonic() - encode_start - frame_interval = temp_video_time - previous_video_time - previous_video_time = temp_video_time + else: + output_future, output_video_time = temp_deque.popleft() + encode_start = time.monotonic() + output_vision_buffer, output_resolution = output_future.result() + peer_bitrate : BitRate = rtc_peer.get('sender_bitrate').value + video_encoder, temp_resolution, temp_bitrate, frame_index = adapt_video_encoder(video_codec, video_encoder, temp_resolution, temp_bitrate, output_resolution, peer_bitrate, frame_index) + output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index) - rtc.adapt_receiver_bitrate(rtc_peer, calculate_receiver_bitrate(rtc_peer, encode_time, frame_interval)) + if output_video_buffer: + rtc.send_video(rtc_peer, output_video_buffer, int(output_video_time * 90000)) - frame_index += 1 - video_event.wait() - video_event.clear() - temp_vision_frame, temp_video_time = video_deque.popleft() + encode_time = time.monotonic() - encode_start + frame_interval = output_video_time - previous_video_time + previous_video_time = output_video_time + + rtc.adapt_receiver_bitrate(rtc_peer, calculate_receiver_bitrate(rtc_peer, encode_time, frame_interval)) + frame_index += 1 destroy_video_encoder(video_codec, video_encoder) rtc.clear_bitrate(rtc_peer) -def receive_video_frames(rtc_peer_video : RtcPeerVideo, video_deque : deque[VideoPack], video_event : threading.Event) -> None: +def receive_video_frames(rtc_peer_video : RtcPeerVideo, video_queue : Queue[VideoPack]) -> None: video_track = rtc_peer_video.get('receiver_track') video_codec = rtc_peer_video.get('codec') - datachannel_library = datachannel_module.create_static_library() video_decoder = create_video_decoder(video_codec) - receive_buffer = ctypes.create_string_buffer(512 * 1024) - available_event = create_event(video_track, datachannel_library) - receive_status_code = -3 - - while receive_status_code == 0 or receive_status_code == -3: - buffer_size = ctypes.c_int(512 * 1024) - receive_status_code = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size)) - - if receive_status_code == 0 and buffer_size.value > 0: - video_buffer = receive_buffer.raw[:buffer_size.value] - fill_video_deque(video_codec, video_decoder, video_buffer, video_deque, video_event) - - if receive_status_code == -3: - available_event.wait() - available_event.clear() + video_frame_handler = partial(handle_video_frame, video_codec, video_decoder, video_queue) + receive_event = create_receive_event(video_track, video_frame_handler) + receive_event.wait() empty_vision_frame = numpy.empty(0) - video_deque.append((empty_vision_frame, 0.0)) - video_event.set() + video_queue.put((empty_vision_frame, 0.0)) destroy_video_decoder(video_codec, video_decoder) -def fill_video_deque(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, video_buffer : bytes, video_deque : deque[VideoPack], video_event : threading.Event) -> None: - vision_frame = decode_video_frame(video_codec, video_decoder, video_buffer) - - if numpy.any(vision_frame): - video_deque.append((vision_frame, time.monotonic())) - video_event.set() - - def process_video_frame(vision_frame : VisionFrame) -> tuple[bytes, Resolution]: output_vision_frame = streamer.process_frame(create_empty_audio_frame(), vision_frame) output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) @@ -202,3 +186,14 @@ def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEn return vpx_encoder.update_bitrate(video_encoder, bitrate) return False + + +def handle_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, video_queue : Queue[VideoPack], track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: + video_buffer = ctypes.string_at(data, size) + vision_frame = decode_video_frame(video_codec, video_decoder, video_buffer) + + if numpy.any(vision_frame): + if video_queue.full(): + video_queue.get_nowait() + + video_queue.put((vision_frame, time.monotonic())) diff --git a/facefusion/libraries/datachannel.py b/facefusion/libraries/datachannel.py index 803395fc..743fd827 100644 --- a/facefusion/libraries/datachannel.py +++ b/facefusion/libraries/datachannel.py @@ -166,7 +166,7 @@ def create_static_library() -> Optional[ctypes.CDLL]: def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcInitLogger.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) ] library.rtcInitLogger.restype = None - library.rtcInitLogger(5, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0)) + library.rtcInitLogger(2, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)(0)) library.rtcCreatePeerConnection.restype = ctypes.c_int @@ -215,8 +215,8 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcChainRtcpReceivingSession.argtypes = [ ctypes.c_int ] library.rtcChainRtcpReceivingSession.restype = ctypes.c_int - library.rtcReceiveMessage.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int) ] - library.rtcReceiveMessage.restype = ctypes.c_int + library.rtcSetFrameCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p) ] + library.rtcSetFrameCallback.restype = ctypes.c_int library.rtcSetUserPointer.argtypes = [ ctypes.c_int, ctypes.c_void_p ] library.rtcSetUserPointer.restype = None @@ -227,8 +227,8 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcRequestBitrate.argtypes = [ ctypes.c_int, ctypes.c_uint ] library.rtcRequestBitrate.restype = ctypes.c_int - library.rtcSetAvailableCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p) ] - library.rtcSetAvailableCallback.restype = ctypes.c_int + library.rtcSetClosedCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p) ] + library.rtcSetClosedCallback.restype = ctypes.c_int return library diff --git a/facefusion/types.py b/facefusion/types.py index ea2c1fe8..6d4949cd 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -97,6 +97,8 @@ Resolution : TypeAlias = Tuple[int, int] AudioCodec : TypeAlias = Literal['opus'] VideoCodec : TypeAlias = Literal['av1', 'vp8'] +FrameHandler : TypeAlias = Callable[..., None] + AomEncoder : TypeAlias = ctypes.Array[ctypes.c_char] AomDecoder : TypeAlias = ctypes.Array[ctypes.c_char] OpusEncoder : TypeAlias = ctypes.c_void_p diff --git a/tests/test_api_stream_audio.py b/tests/test_api_stream_audio.py index e651da9e..91950bba 100644 --- a/tests/test_api_stream_audio.py +++ b/tests/test_api_stream_audio.py @@ -1,13 +1,14 @@ import ctypes import threading -from collections import deque +from functools import partial +from queue import Queue from unittest.mock import MagicMock, patch import numpy import pytest from facefusion import rtc, rtc_store, state_manager -from facefusion.apis.stream_audio import fill_audio_deque, receive_audio_frames, run_audio_encode_loop +from facefusion.apis.stream_audio import handle_audio_frame, receive_audio_frames, run_audio_encode_loop from facefusion.download import conditional_download from facefusion.ffmpeg import read_audio_buffer from facefusion.hash_helper import create_hash @@ -52,54 +53,36 @@ def test_run_audio_encode_loop() -> None: 'receiver_bitrate': ctypes.c_uint(0) } - audio_deque : deque[AudioPack] = deque() - audio_event = threading.Event() + audio_queue : Queue[AudioPack] = Queue(maxsize = 300) - audio_deque.append((audio_frame, 0.100)) - audio_event.set() + audio_queue.put((audio_frame, 0.100)) encoder_mock = MagicMock() encoder_mock.encode.return_value = bytes([ 1 ] * 32) with patch('facefusion.apis.stream_audio.opus_encoder.encode', encoder_mock.encode): with patch('facefusion.apis.stream_audio.rtc.send_audio') as send_audio_mock: - audio_loop_thread = threading.Thread(target = run_audio_encode_loop, args = (rtc_peer, audio_deque, audio_event), daemon = True) + audio_loop_thread = threading.Thread(target = run_audio_encode_loop, args = (rtc_peer, audio_queue), daemon = True) audio_loop_thread.start() - audio_deque.append((numpy.empty(0), 0.0)) - audio_event.set() + audio_queue.put((numpy.empty(0), 0.0)) audio_loop_thread.join(timeout = 5.0) assert encoder_mock.encode.called is True assert send_audio_mock.called is True -def test_fill_audio_deque() -> None: - audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) - audio_frame = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 - audio_decoder_mock = MagicMock() - audio_deque : deque[AudioPack] = deque() - audio_event = threading.Event() - - with patch('facefusion.apis.stream_audio.decode_audio_frame', return_value = audio_frame.tobytes()): - fill_audio_deque('opus', audio_decoder_mock, audio_frame.tobytes(), audio_deque, audio_event) - - buffer_frame, _ = audio_deque.popleft() - - assert audio_event.is_set() - assert create_hash(buffer_frame.tobytes()) == create_hash(audio_frame.tobytes()) - - @pytest.mark.parametrize('audio_codec', [ 'opus' ]) def test_receive_audio_frames(audio_codec : AudioCodec) -> None: audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) audio_frame = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 - audio_deque : deque[AudioPack] = deque() - audio_event = threading.Event() + audio_queue : Queue[AudioPack] = Queue(maxsize = 300) - datachannel_library_mock = MagicMock() - datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ] + datachannel_mock = MagicMock() + ready_event = threading.Event() + # todo: lambda not allowed + datachannel_mock.rtcSetClosedCallback.side_effect = partial(lambda event, *args: event.set(), ready_event) - with patch('facefusion.apis.stream_audio.datachannel_module.create_static_library', return_value = datachannel_library_mock): + with patch('facefusion.libraries.datachannel.create_static_library', return_value = datachannel_mock): with patch('facefusion.apis.stream_audio.decode_audio_frame', return_value = audio_frame.tobytes()): rtc_peer_audio : RtcPeerAudio =\ { @@ -107,10 +90,27 @@ def test_receive_audio_frames(audio_codec : AudioCodec) -> None: 'receiver_track': 0, 'codec': audio_codec } - audio_receiver_thread = threading.Thread(target = receive_audio_frames, args = (rtc_peer_audio, audio_deque, audio_event), daemon = True) + audio_receiver_thread = threading.Thread(target = receive_audio_frames, args = (rtc_peer_audio, audio_queue), daemon = True) audio_receiver_thread.start() + ready_event.wait(timeout = 5.0) + datachannel_mock.rtcSetFrameCallback.call_args[0][1](0, bytes([ 0 ]), 1, None, None) + datachannel_mock.rtcSetClosedCallback.call_args[0][1](0, None) audio_receiver_thread.join(timeout = 5.0) - buffer_frame, _ = audio_deque.popleft() + buffer_frame, _ = audio_queue.get_nowait() + + assert create_hash(buffer_frame.tobytes()) == create_hash(audio_frame.tobytes()) + + +def test_handle_audio_frame() -> None: + audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) + audio_frame = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 + audio_decoder_mock = MagicMock() + audio_queue : Queue[AudioPack] = Queue(maxsize = 300) + + with patch('facefusion.apis.stream_audio.decode_audio_frame', return_value = audio_frame.tobytes()): + handle_audio_frame('opus', audio_decoder_mock, audio_queue, 0, ctypes.c_void_p(), 1, ctypes.c_void_p(), ctypes.c_void_p()) + + buffer_frame, _ = audio_queue.get_nowait() assert create_hash(buffer_frame.tobytes()) == create_hash(audio_frame.tobytes()) diff --git a/tests/test_api_stream_manager.py b/tests/test_api_stream_manager.py index cfb46afa..424f4d50 100644 --- a/tests/test_api_stream_manager.py +++ b/tests/test_api_stream_manager.py @@ -1,4 +1,3 @@ -import asyncio import ctypes import threading from unittest.mock import AsyncMock, patch @@ -138,7 +137,7 @@ def test_run_peer_loop(video_codec : VideoCodec, payload_type : int, session_id with patch('facefusion.apis.stream_manager.receive_video_frames'): with patch('facefusion.apis.stream_manager.run_video_encode_loop'): - thread = threading.Thread(target = asyncio.run, args = (run_peer_loop(session_id, rtc_peer),), daemon = True) + thread = threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True) thread.start() thread.join(timeout = 5.0) diff --git a/tests/test_api_stream_video.py b/tests/test_api_stream_video.py index f5376b5d..832f2501 100644 --- a/tests/test_api_stream_video.py +++ b/tests/test_api_stream_video.py @@ -1,7 +1,8 @@ import ctypes import struct import threading -from collections import deque +from functools import partial +from queue import Queue from unittest.mock import MagicMock, patch import cv2 @@ -9,7 +10,7 @@ import numpy import pytest from facefusion import rtc, rtc_store, state_manager -from facefusion.apis.stream_video import create_video_decoder, create_video_encoder, decode_video_frame, destroy_video_decoder, destroy_video_encoder, encode_video_frame, fill_video_deque, receive_video_frames, run_video_encode_loop, update_video_encoder_bitrate +from facefusion.apis.stream_video import create_video_decoder, create_video_encoder, decode_video_frame, destroy_video_decoder, destroy_video_encoder, encode_video_frame, handle_video_frame, receive_video_frames, run_video_encode_loop, update_video_encoder_bitrate from facefusion.codecs import aom_encoder, vpx_encoder from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download @@ -23,6 +24,7 @@ from .assert_helper import get_test_example_file, get_test_examples_directory @pytest.fixture(scope = 'module', autouse = True) def before_all() -> None: state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) + state_manager.init_item('execution_thread_count', 8) state_manager.init_item('processors', []) aom_module.pre_check() @@ -59,18 +61,15 @@ def test_run_video_encode_loop(video_codec : VideoCodec, payload_type : int) -> 'receiver_bitrate': ctypes.c_uint(8000) } - video_deque : deque[VideoPack] = deque() - video_event = threading.Event() + video_queue : Queue[VideoPack] = Queue(maxsize = 30) - video_deque.append((video_frame, 0.1)) - video_event.set() + video_queue.put((video_frame, 0.1)) with patch('facefusion.apis.stream_video.rtc.send_video') as send_video_mock: - encode_loop_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_deque, video_event), daemon = True) + encode_loop_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_queue), daemon = True) encode_loop_thread.start() empty_vision_frame = numpy.empty(0) - video_deque.append((empty_vision_frame, 0.0)) - video_event.set() + video_queue.put((empty_vision_frame, 0.0)) encode_loop_thread.join(timeout = 5.0) assert send_video_mock.called @@ -89,13 +88,14 @@ def test_run_video_encode_loop(video_codec : VideoCodec, payload_type : int) -> @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) def test_receive_video_frames(video_codec : VideoCodec) -> None: video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - video_deque : deque[VideoPack] = deque() - video_event = threading.Event() + video_queue : Queue[VideoPack] = Queue(maxsize = 30) - datachannel_library_mock = MagicMock() - datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ] + datachannel_mock = MagicMock() + ready_event = threading.Event() + #todo: lambda not allowed + datachannel_mock.rtcSetClosedCallback.side_effect = partial(lambda event, *args: event.set(), ready_event) - with patch('facefusion.apis.stream_video.datachannel_module.create_static_library', return_value = datachannel_library_mock): + with patch('facefusion.libraries.datachannel.create_static_library', return_value = datachannel_mock): with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): rtc_peer_video : RtcPeerVideo =\ { @@ -103,11 +103,14 @@ def test_receive_video_frames(video_codec : VideoCodec) -> None: 'receiver_track': 0, 'codec': video_codec } - video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer_video, video_deque, video_event), daemon = True) + video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer_video, video_queue), daemon = True) video_receiver_thread.start() + ready_event.wait(timeout = 5.0) + datachannel_mock.rtcSetFrameCallback.call_args[0][1](0, bytes([ 0 ]), 1, None, None) + datachannel_mock.rtcSetClosedCallback.call_args[0][1](0, None) video_receiver_thread.join(timeout = 5.0) - vision_frame, _ = video_deque.popleft() + vision_frame, _ = video_queue.get_nowait() if is_linux() or is_windows(): assert create_hash(vision_frame.tobytes()) == 'a17439db' @@ -116,37 +119,6 @@ def test_receive_video_frames(video_codec : VideoCodec) -> None: assert create_hash(vision_frame.tobytes()) == '38d00e2a' -@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) -def test_fill_video_deque(video_codec : VideoCodec) -> None: - video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - input_buffer = cv2.cvtColor(video_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - video_encoder = create_video_encoder(video_codec, (426, 226), 1000) - video_decoder = create_video_decoder(video_codec) - encode_buffer = encode_video_frame(video_codec, video_encoder, input_buffer, (426, 226), 0) - video_deque : deque[VideoPack] = deque() - video_event = threading.Event() - - fill_video_deque(video_codec, video_decoder, encode_buffer, video_deque, video_event) - - vision_frame, _ = video_deque.popleft() - - assert video_event.is_set() - - if is_linux() or is_windows(): - if video_codec == 'av1': - assert create_hash(vision_frame.tobytes()) == 'b5b6486d' - - if video_codec == 'vp8': - assert create_hash(vision_frame.tobytes()) == '99ef2c25' - - if is_macos(): - if video_codec == 'av1': - assert create_hash(vision_frame.tobytes()) == '74e9926f' - - if video_codec == 'vp8': - assert create_hash(vision_frame.tobytes()) == 'ff3ecb43' - - @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) def test_encode_and_decode_video_frame(video_codec : VideoCodec) -> None: video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) @@ -165,7 +137,7 @@ def test_encode_and_decode_video_frame(video_codec : VideoCodec) -> None: if is_macos(): if video_codec == 'av1': - assert create_hash(decode_buffer) == '74e9926f' + assert create_hash(decode_buffer) == 'eafd1fab' if video_codec == 'vp8': assert create_hash(decode_buffer) == 'ff3ecb43' @@ -178,10 +150,11 @@ def test_create_and_destroy_video_decoder(video_codec : VideoCodec) -> None: video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) input_buffer = cv2.cvtColor(video_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'av1': video_encoder = aom_encoder.create((426, 226), 1000, 1, 0) encode_buffer = aom_encoder.encode(video_encoder, input_buffer, (426, 226), 0) - + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'vp8': video_encoder = vpx_encoder.create((426, 226), 1000, 1, 0) encode_buffer = vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 0) @@ -201,17 +174,18 @@ def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None: input_buffer = cv2.cvtColor(video_frame, cv2.COLOR_BGR2YUV_I420).tobytes() video_encoder = create_video_encoder(video_codec, (426, 226), 4000) + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'av1': assert aom_encoder.encode(video_encoder, input_buffer, (426, 226), 0) - + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'vp8': assert vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 0) destroy_video_encoder(video_codec, video_encoder) - + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'av1': assert aom_encoder.encode(video_encoder, input_buffer, (426, 226), 1) == bytes() - + # todo: this head be hash based checks before, now the codnitions seem pointless if video_codec == 'vp8': assert vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 1) == bytes() @@ -235,3 +209,21 @@ def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 6000 destroy_video_encoder(video_codec, video_encoder) + + +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) +def test_handle_video_frame(video_codec : VideoCodec) -> None: + video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + video_decoder = create_video_decoder(video_codec) + video_queue : Queue[VideoPack] = Queue(maxsize = 30) + + with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): + handle_video_frame(video_codec, video_decoder, video_queue, 0, ctypes.c_void_p(), 1, ctypes.c_void_p(), ctypes.c_void_p()) + + vision_frame, _ = video_queue.get_nowait() + + if is_linux() or is_windows(): + assert create_hash(vision_frame.tobytes()) == 'a17439db' + + if is_macos(): + assert create_hash(vision_frame.tobytes()) == '38d00e2a' diff --git a/tests/test_codec_aom_decoder.py b/tests/test_codec_aom_decoder.py index 884ea663..1de82c04 100644 --- a/tests/test_codec_aom_decoder.py +++ b/tests/test_codec_aom_decoder.py @@ -26,7 +26,7 @@ def before_all() -> None: def test_create() -> None: assert create(1) - with patch('facefusion.codecs.aom_decoder.aom_module.create_static_library', return_value = None): + with patch('facefusion.libraries.aom.create_static_library', return_value = None): assert create(1) is None @@ -42,7 +42,7 @@ def test_decode() -> None: assert create_hash(decode(aom_decoder, encode_buffer).get('buffer')) == 'e3c0ebd8' if is_macos(): - assert create_hash(decode(aom_decoder, encode_buffer).get('buffer')) == '0a0ab3d0' + assert create_hash(decode(aom_decoder, encode_buffer).get('buffer')) == 'c8c6fdaa' def test_destroy() -> None: diff --git a/tests/test_codec_aom_encoder.py b/tests/test_codec_aom_encoder.py index 589d9792..811a432d 100644 --- a/tests/test_codec_aom_encoder.py +++ b/tests/test_codec_aom_encoder.py @@ -25,7 +25,7 @@ def before_all() -> None: def test_create() -> None: assert create((320, 240), 1000, 8, 16) - with patch('facefusion.codecs.aom_encoder.aom_module.create_static_library', return_value = None): + with patch('facefusion.libraries.aom.create_static_library', return_value = None): assert create((320, 240), 1000, 8, 16) is None diff --git a/tests/test_codec_opus_decoder.py b/tests/test_codec_opus_decoder.py index b809c7f7..b8eceaf5 100644 --- a/tests/test_codec_opus_decoder.py +++ b/tests/test_codec_opus_decoder.py @@ -26,7 +26,7 @@ def before_all() -> None: def test_create() -> None: assert create(48000, 2) - with patch('facefusion.codecs.opus_decoder.opus_module.create_static_library', return_value = None): + with patch('facefusion.libraries.opus.create_static_library', return_value = None): assert create(48000, 2) is None diff --git a/tests/test_codec_opus_encoder.py b/tests/test_codec_opus_encoder.py index 798f11f5..9450439a 100644 --- a/tests/test_codec_opus_encoder.py +++ b/tests/test_codec_opus_encoder.py @@ -25,7 +25,7 @@ def before_all() -> None: def test_create() -> None: assert create(48000, 2) - with patch('facefusion.codecs.opus_encoder.opus_module.create_static_library', return_value = None): + with patch('facefusion.libraries.opus.create_static_library', return_value = None): assert create(48000, 2) is None diff --git a/tests/test_codec_vpx_decoder.py b/tests/test_codec_vpx_decoder.py index ed2aca81..e3f27390 100644 --- a/tests/test_codec_vpx_decoder.py +++ b/tests/test_codec_vpx_decoder.py @@ -26,7 +26,7 @@ def before_all() -> None: def test_create() -> None: assert create(1) - with patch('facefusion.codecs.vpx_decoder.vpx_module.create_static_library', return_value = None): + with patch('facefusion.libraries.vpx.create_static_library', return_value = None): assert create(1) is None diff --git a/tests/test_codec_vpx_encoder.py b/tests/test_codec_vpx_encoder.py index 334aef5a..289adedd 100644 --- a/tests/test_codec_vpx_encoder.py +++ b/tests/test_codec_vpx_encoder.py @@ -25,7 +25,7 @@ def before_all() -> None: def test_create() -> None: assert create((320, 240), 1000, 8, 16) - with patch('facefusion.codecs.vpx_encoder.vpx_module.create_static_library', return_value = None): + with patch('facefusion.libraries.vpx.create_static_library', return_value = None): assert create((320, 240), 1000, 8, 16) is None