diff --git a/facefusion/apis/stream_audio.py b/facefusion/apis/stream_audio.py index b4f4f010..54e32d7e 100644 --- a/facefusion/apis/stream_audio.py +++ b/facefusion/apis/stream_audio.py @@ -2,33 +2,33 @@ import ctypes import time from functools import partial from queue import Queue -from typing import Optional +from typing import Optional, Tuple import numpy from facefusion import rtc from facefusion.apis.stream_event import create_receive_event from facefusion.codecs import opus_decoder, opus_encoder -from facefusion.types import AudioCodec, AudioPack, OpusDecoder, RtcPeer, RtcPeerAudio +from facefusion.types import AudioCodec, AudioFrame, OpusDecoder, RtcPeer, RtcPeerAudio -def run_audio_encode_loop(rtc_peer : RtcPeer, audio_queue : Queue[AudioPack]) -> None: - temp_audio_frame, temp_audio_time = audio_queue.get() +def run_audio_encode_loop(rtc_peer : RtcPeer, audio_queue : Queue[Tuple[float, AudioFrame]]) -> None: + temp_audio_time, temp_audio_frame = audio_queue.get() audio_encoder = opus_encoder.create(48000, 2) while numpy.any(temp_audio_frame): - output_audio_buffer = opus_encoder.encode(audio_encoder, temp_audio_frame.tobytes(), 960) + audio_buffer = opus_encoder.encode(audio_encoder, temp_audio_frame.tobytes(), 960) - if output_audio_buffer: + if audio_buffer: audio_timestamp = int(temp_audio_time * 48000) - rtc.send_audio(rtc_peer, output_audio_buffer, audio_timestamp) + rtc.send_audio(rtc_peer, audio_buffer, audio_timestamp) - temp_audio_frame, temp_audio_time = audio_queue.get() + temp_audio_time, temp_audio_frame = audio_queue.get() opus_encoder.destroy(audio_encoder) -def receive_audio_frames(rtc_peer_audio : RtcPeerAudio, audio_queue : Queue[AudioPack]) -> None: +def receive_audio_frames(rtc_peer_audio : RtcPeerAudio, audio_queue : Queue[Tuple[float, AudioFrame]]) -> None: audio_track = rtc_peer_audio.get('receiver_track') audio_codec = rtc_peer_audio.get('codec') audio_decoder = create_audio_decoder(audio_codec) @@ -36,8 +36,9 @@ def receive_audio_frames(rtc_peer_audio : RtcPeerAudio, audio_queue : Queue[Audi audio_frame_handler = partial(handle_audio_frame, audio_codec, audio_decoder, audio_queue) receive_event = create_receive_event(audio_track, audio_frame_handler) receive_event.wait() + empty_audio_frame = numpy.empty(0) - audio_queue.put((empty_audio_frame, 0.0)) + audio_queue.put((0.0, empty_audio_frame)) destroy_audio_decoder(audio_codec, audio_decoder) @@ -58,10 +59,11 @@ def destroy_audio_decoder(audio_codec : AudioCodec, audio_decoder : OpusDecoder) opus_decoder.destroy(audio_decoder) -def handle_audio_frame(audio_codec : AudioCodec, audio_decoder : OpusDecoder, audio_queue : Queue[AudioPack], track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: +#todo: Alias Time for float +def handle_audio_frame(audio_codec : AudioCodec, audio_decoder : OpusDecoder, audio_queue : Queue[Tuple[float, AudioFrame]], track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: audio_buffer = ctypes.string_at(data, size) audio_frame = decode_audio_frame(audio_codec, audio_decoder, audio_buffer) if audio_frame: temp_audio_frame = numpy.frombuffer(audio_frame, dtype = numpy.float32) - audio_queue.put((temp_audio_frame, time.monotonic())) + audio_queue.put((time.monotonic(), temp_audio_frame)) diff --git a/facefusion/apis/stream_manager.py b/facefusion/apis/stream_manager.py index d4eb7b31..5f387da9 100644 --- a/facefusion/apis/stream_manager.py +++ b/facefusion/apis/stream_manager.py @@ -1,19 +1,20 @@ import ctypes import threading from collections.abc import AsyncIterator +from concurrent.futures import Future, ThreadPoolExecutor from queue import Queue -from typing import Optional +from typing import Optional, Tuple import cv2 import numpy from starlette.websockets import WebSocket -from facefusion import rtc, rtc_store, streamer +from facefusion import rtc, rtc_store, state_manager, streamer from facefusion.apis.stream_audio import receive_audio_frames, run_audio_encode_loop from facefusion.apis.stream_video import receive_video_frames, run_video_encode_loop from facefusion.audio import create_empty_audio_frame from facefusion.libraries import datachannel as datachannel_module -from facefusion.types import AudioCodec, AudioPack, PeerConnection, RtcPeer, RtcPeerAudio, SdpAnswer, SdpOffer, SessionId, VideoCodec, VideoPack, VisionFrame +from facefusion.types import AudioCodec, AudioFrame, PeerConnection, Resolution, RtcPeer, RtcPeerAudio, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame async def process_image(websocket : WebSocket) -> None: @@ -104,10 +105,13 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: - video_queue : Queue[VideoPack] = Queue(maxsize = 30) - audio_queue : Queue[AudioPack] = Queue(maxsize = 300) + execution_thread_count = state_manager.get_item('execution_thread_count') + #todo: is bytes, Resolution not a XXXPointer type + video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]] = Queue(maxsize = execution_thread_count) + audio_queue : Queue[Tuple[float, AudioFrame]] = Queue(maxsize = execution_thread_count * 10) + video_executor = ThreadPoolExecutor(max_workers = execution_thread_count) - video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer.get('video'), video_queue), daemon = True) + video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer.get('video'), video_queue, video_executor), daemon = True) video_encoder_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_queue), daemon = True) video_receiver_thread.start() video_encoder_thread.start() @@ -122,6 +126,7 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: video_receiver_thread.join() video_encoder_thread.join() + video_executor.shutdown(wait = True) rtc_store.delete_peers(session_id) diff --git a/facefusion/apis/stream_video.py b/facefusion/apis/stream_video.py index 278f9694..dfe2a6e6 100644 --- a/facefusion/apis/stream_video.py +++ b/facefusion/apis/stream_video.py @@ -1,101 +1,85 @@ import ctypes import time -from collections import deque -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from functools import partial from queue import Queue -from typing import Optional +from typing import Optional, Tuple import cv2 import numpy -from facefusion import rtc, state_manager, streamer +from facefusion import rtc, streamer from facefusion.apis.stream_event import create_receive_event from facefusion.audio import create_empty_audio_frame from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder -from facefusion.types import AomDecoder, AomEncoder, AomPointer, BitRate, Resolution, RtcPeer, RtcPeerVideo, VideoCodec, VideoPack, VisionFrame, VpxDecoder, VpxEncoder, VpxPointer +from facefusion.types import AomDecoder, AomEncoder, AomPointer, BitRate, Resolution, RtcPeer, RtcPeerVideo, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder, VpxPointer -def run_video_encode_loop(rtc_peer : RtcPeer, video_queue : Queue[VideoPack]) -> None: +def run_video_encode_loop(rtc_peer : RtcPeer, video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]]) -> None: video_codec = rtc_peer.get('video').get('codec') - temp_vision_frame, temp_video_time = video_queue.get() + video_time, video_future = video_queue.get() + video_buffer, video_resolution = video_future.result() - if numpy.any(temp_vision_frame): - temp_resolution : Resolution = (temp_vision_frame.shape[1], temp_vision_frame.shape[0]) + if video_buffer: + temp_resolution : Resolution = video_resolution temp_bitrate : BitRate = 8000 video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate) - previous_video_time = temp_video_time - #todo: fix typing once issue below is sorted out - temp_deque = deque() #type:ignore[var-annotated] - execution_thread_count = state_manager.get_item('execution_thread_count') + temp_video_time = video_time frame_index = 0 - with ThreadPoolExecutor(max_workers = execution_thread_count) as executor: - while numpy.any(temp_vision_frame) or temp_deque: + while video_buffer: + encode_start = time.monotonic() + sender_bitrate = calculate_sender_bitrate(rtc_peer, temp_bitrate) - if numpy.any(temp_vision_frame) and len(temp_deque) < execution_thread_count: - #todo: why does the deque contain a future and not just the data - temp_deque.append((executor.submit(process_video_frame, temp_vision_frame), temp_video_time)) - temp_vision_frame, temp_video_time = video_queue.get() + if video_resolution[0] - temp_resolution[0] or video_resolution[1] - temp_resolution[1]: + temp_resolution = video_resolution + update_video_encoder_resolution(video_codec, video_encoder, temp_resolution) - else: - output_future, output_video_time = temp_deque.popleft() - encode_start = time.monotonic() - output_vision_buffer, output_resolution = output_future.result() - peer_bitrate : BitRate = rtc_peer.get('sender_bitrate').value - video_encoder, temp_resolution, temp_bitrate, frame_index = adapt_video_encoder(video_codec, video_encoder, temp_resolution, temp_bitrate, output_resolution, peer_bitrate, frame_index) - output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index) + if sender_bitrate - temp_bitrate: + temp_bitrate = sender_bitrate + update_video_encoder_bitrate(video_codec, video_encoder, temp_bitrate) - if output_video_buffer: - rtc.send_video(rtc_peer, output_video_buffer, int(output_video_time * 90000)) + __video_buffer__ = encode_video_frame(video_codec, video_encoder, video_buffer, temp_resolution, frame_index) - encode_time = time.monotonic() - encode_start - frame_interval = output_video_time - previous_video_time - previous_video_time = output_video_time + if __video_buffer__: + video_timestamp = int(video_time * 90000) + rtc.send_video(rtc_peer, __video_buffer__, video_timestamp) - rtc.adapt_receiver_bitrate(rtc_peer, calculate_receiver_bitrate(rtc_peer, encode_time, frame_interval)) - frame_index += 1 + encode_time = time.monotonic() - encode_start + frame_interval = video_time - temp_video_time + temp_video_time = video_time + + receiver_bitrate = calculate_receiver_bitrate(rtc_peer, encode_time, frame_interval) + rtc.adapt_receiver_bitrate(rtc_peer, receiver_bitrate) + frame_index += 1 + + video_time, video_future = video_queue.get() + video_buffer, video_resolution = video_future.result() destroy_video_encoder(video_codec, video_encoder) rtc.clear_bitrate(rtc_peer) -def receive_video_frames(rtc_peer_video : RtcPeerVideo, video_queue : Queue[VideoPack]) -> None: +def receive_video_frames(rtc_peer_video : RtcPeerVideo, video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]], video_executor : ThreadPoolExecutor) -> None: video_track = rtc_peer_video.get('receiver_track') video_codec = rtc_peer_video.get('codec') video_decoder = create_video_decoder(video_codec) - video_frame_handler = partial(handle_video_frame, video_codec, video_decoder, video_queue) + video_frame_handler = partial(handle_video_frame, video_codec, video_decoder, video_queue, video_executor) receive_event = create_receive_event(video_track, video_frame_handler) receive_event.wait() - empty_vision_frame = numpy.empty(0) - video_queue.put((empty_vision_frame, 0.0)) + + empty_future : Future[Tuple[bytes, Resolution]] = Future() + empty_future.set_result((bytes(), (0, 0))) + video_queue.put((0.0, empty_future)) destroy_video_decoder(video_codec, video_decoder) -def process_video_frame(vision_frame : VisionFrame) -> tuple[bytes, Resolution]: - output_vision_frame = streamer.process_frame(create_empty_audio_frame(), vision_frame) +def process_video_frame(input_vision_frame : VisionFrame) -> Tuple[bytes, Resolution]: + output_vision_frame = streamer.process_frame(create_empty_audio_frame(), input_vision_frame) output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) - output_vision_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - return output_vision_buffer, output_resolution - - -def adapt_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, resolution : Resolution, bitrate : BitRate, output_resolution : Resolution, peer_bitrate : BitRate, frame_index : int) -> tuple[VpxEncoder | AomEncoder, Resolution, BitRate, int]: - if output_resolution[0] - resolution[0] or output_resolution[1] - resolution[1]: - destroy_video_encoder(video_codec, video_encoder) - resolution = output_resolution - video_encoder = create_video_encoder(video_codec, resolution, bitrate) - frame_index = 0 - - if peer_bitrate and peer_bitrate - bitrate: - bitrate = peer_bitrate - - if not update_video_encoder_bitrate(video_codec, video_encoder, bitrate): - destroy_video_encoder(video_codec, video_encoder) - video_encoder = create_video_encoder(video_codec, resolution, bitrate) - frame_index = 0 - - return video_encoder, resolution, bitrate, frame_index + output_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() + return output_buffer, output_resolution def calculate_receiver_bitrate(rtc_peer : RtcPeer, encode_time : float, frame_interval : float) -> BitRate: @@ -111,6 +95,18 @@ def calculate_receiver_bitrate(rtc_peer : RtcPeer, encode_time : float, frame_in return bitrate +#todo: does not feel final as this is an clamp and not calculate +def calculate_sender_bitrate(rtc_peer : RtcPeer, bitrate : BitRate) -> BitRate: + min_bitrate : BitRate = 500 + max_bitrate : BitRate = 8000 + peer_bitrate : BitRate = rtc_peer.get('sender_bitrate').value + + if peer_bitrate > 0: + bitrate = max(min_bitrate, min(max_bitrate, peer_bitrate)) + + return bitrate + + def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, input_buffer : bytes) -> Optional[VisionFrame]: if video_codec == 'av1': aom_pointer = aom_decoder.decode(video_decoder, input_buffer) @@ -179,6 +175,16 @@ def destroy_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | vpx_encoder.destroy(video_encoder) +def update_video_encoder_resolution(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, frame_resolution : Resolution) -> bool: + if video_codec == 'av1': + return aom_encoder.update_resolution(video_encoder, frame_resolution) + + if video_codec in [ 'vp8', 'vp9' ]: + return vpx_encoder.update_resolution(video_encoder, frame_resolution) + + return False + + def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, bitrate : BitRate) -> bool: if video_codec == 'av1': return aom_encoder.update_bitrate(video_encoder, bitrate) @@ -189,12 +195,11 @@ def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEn return False -def handle_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, video_queue : Queue[VideoPack], track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: +#todo: we can remove the dead args or pass audio buffer +def handle_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]], video_executor : ThreadPoolExecutor, track : int, data : ctypes.c_void_p, size : int, info : ctypes.c_void_p, pointer : ctypes.c_void_p) -> None: video_buffer = ctypes.string_at(data, size) vision_frame = decode_video_frame(video_codec, video_decoder, video_buffer) - if numpy.any(vision_frame): - if video_queue.full(): - video_queue.get_nowait() - - video_queue.put((vision_frame, time.monotonic())) + if numpy.any(vision_frame) and video_queue.qsize() < video_queue.maxsize: + video_future = video_executor.submit(process_video_frame, vision_frame) + video_queue.put((time.monotonic(), video_future)) diff --git a/facefusion/codecs/aom_encoder.py b/facefusion/codecs/aom_encoder.py index b31d7056..e696bf66 100644 --- a/facefusion/codecs/aom_encoder.py +++ b/facefusion/codecs/aom_encoder.py @@ -66,6 +66,17 @@ def collect(aom_encoder : AomEncoder) -> bytes: return bytes().join(output_parts) +def update_resolution(aom_encoder : AomEncoder, frame_resolution : Resolution) -> bool: + aom_library = aom_module.create_static_library() + + if aom_library: + struct.pack_into('I', aom_encoder, 128 + 12, frame_resolution[0]) + struct.pack_into('I', aom_encoder, 128 + 16, frame_resolution[1]) + return aom_library.aom_codec_enc_config_set(aom_encoder, ctypes.cast(ctypes.addressof(aom_encoder) + 128, ctypes.c_void_p)) == 0 + + return False + + def update_bitrate(aom_encoder : AomEncoder, bitrate : BitRate) -> bool: aom_library = aom_module.create_static_library() diff --git a/facefusion/codecs/vpx_encoder.py b/facefusion/codecs/vpx_encoder.py index c095bbca..ef3da620 100644 --- a/facefusion/codecs/vpx_encoder.py +++ b/facefusion/codecs/vpx_encoder.py @@ -76,6 +76,17 @@ def collect(vpx_encoder : VpxEncoder) -> bytes: return bytes().join(output_parts) +def update_resolution(vpx_encoder : VpxEncoder, frame_resolution : Resolution) -> bool: + vpx_library = vpx_module.create_static_library() + + if vpx_library: + struct.pack_into('I', vpx_encoder, 64 + 12, frame_resolution[0]) + struct.pack_into('I', vpx_encoder, 64 + 16, frame_resolution[1]) + return vpx_library.vpx_codec_enc_config_set(vpx_encoder, ctypes.cast(ctypes.addressof(vpx_encoder) + 64, ctypes.c_void_p)) == 0 + + return False + + def update_bitrate(vpx_encoder : VpxEncoder, bitrate : BitRate) -> bool: vpx_library = vpx_module.create_static_library() diff --git a/facefusion/types.py b/facefusion/types.py index 5f942a38..0296ef93 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -320,9 +320,6 @@ RtcPeer = TypedDict('RtcPeer', }) RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]] -VideoPack : TypeAlias = tuple[VisionFrame, float] -AudioPack : TypeAlias = tuple[AudioFrame, float] - SdpAudioMedia = TypedDict('SdpAudioMedia', { 'codec': AudioCodec, diff --git a/tests/test_api_stream_audio.py b/tests/test_api_stream_audio.py index 9048b6b4..dc20293f 100644 --- a/tests/test_api_stream_audio.py +++ b/tests/test_api_stream_audio.py @@ -2,6 +2,7 @@ import ctypes import threading from functools import partial from queue import Queue +from typing import Tuple from unittest.mock import MagicMock, patch import numpy @@ -13,7 +14,7 @@ from facefusion.download import conditional_download from facefusion.ffmpeg import read_audio_buffer from facefusion.hash_helper import create_hash from facefusion.libraries import datachannel as datachannel_module, opus as opus_module -from facefusion.types import AudioCodec, AudioPack, FrameHandler, RtcPeer, RtcPeerAudio +from facefusion.types import AudioCodec, AudioFrame, FrameHandler, RtcPeer, RtcPeerAudio from .assert_helper import get_test_example_file, get_test_examples_directory @@ -57,9 +58,9 @@ def test_run_audio_encode_loop() -> None: 'receiver_bitrate': ctypes.c_uint(0) } - audio_queue : Queue[AudioPack] = Queue(maxsize = 300) + audio_queue : Queue[Tuple[float, AudioFrame]] = Queue(maxsize = 300) - audio_queue.put((audio_frame, 0.100)) + audio_queue.put((0.100, audio_frame)) encoder_mock = MagicMock() encoder_mock.encode.return_value = bytes([ 1 ] * 32) @@ -68,7 +69,7 @@ def test_run_audio_encode_loop() -> None: with patch('facefusion.apis.stream_audio.rtc.send_audio') as send_audio_mock: audio_loop_thread = threading.Thread(target = run_audio_encode_loop, args = (rtc_peer, audio_queue), daemon = True) audio_loop_thread.start() - audio_queue.put((numpy.empty(0), 0.0)) + audio_queue.put((0.0, numpy.empty(0))) audio_loop_thread.join(timeout = 5.0) assert encoder_mock.encode.called is True @@ -79,7 +80,7 @@ def test_run_audio_encode_loop() -> None: def test_receive_audio_frames(audio_codec : AudioCodec) -> None: audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) audio_frame = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 - audio_queue : Queue[AudioPack] = Queue(maxsize = 300) + audio_queue : Queue[Tuple[float, AudioFrame]] = Queue(maxsize = 300) datachannel_mock = MagicMock() ready_event = threading.Event() @@ -100,20 +101,20 @@ def test_receive_audio_frames(audio_codec : AudioCodec) -> None: datachannel_mock.rtcSetClosedCallback.call_args[0][1](0, None) audio_receiver_thread.join(timeout = 5.0) - buffer_frame, _ = audio_queue.get_nowait() + _, temp_audio_frame = audio_queue.get_nowait() - assert create_hash(buffer_frame.tobytes()) == create_hash(audio_frame.tobytes()) + assert create_hash(temp_audio_frame.tobytes()) == create_hash(audio_frame.tobytes()) def test_handle_audio_frame() -> None: audio_buffer = read_audio_buffer(get_test_example_file('source.mp3'), 48000, 16, 2) audio_frame = numpy.frombuffer(audio_buffer, dtype = numpy.int16).astype(numpy.float32) / 32768.0 audio_decoder_mock = MagicMock() - audio_queue : Queue[AudioPack] = Queue(maxsize = 300) + audio_queue : Queue[Tuple[float, AudioFrame]] = Queue(maxsize = 300) with patch('facefusion.apis.stream_audio.decode_audio_frame', return_value = audio_frame.tobytes()): handle_audio_frame('opus', audio_decoder_mock, audio_queue, 0, ctypes.c_void_p(), 1, ctypes.c_void_p(), ctypes.c_void_p()) - buffer_frame, _ = audio_queue.get_nowait() + _, temp_audio_frame = audio_queue.get_nowait() - assert create_hash(buffer_frame.tobytes()) == create_hash(audio_frame.tobytes()) + assert create_hash(temp_audio_frame.tobytes()) == create_hash(audio_frame.tobytes()) diff --git a/tests/test_api_stream_manager.py b/tests/test_api_stream_manager.py index 424f4d50..6e1bd346 100644 --- a/tests/test_api_stream_manager.py +++ b/tests/test_api_stream_manager.py @@ -17,6 +17,7 @@ from .assert_helper import get_test_example_file, get_test_examples_directory @pytest.fixture(scope = 'module', autouse = True) def before_all() -> None: state_manager.init_item('download_providers', [ 'github', 'huggingface' ]) + state_manager.init_item('execution_thread_count', 8) state_manager.init_item('processors', []) datachannel_module.pre_check() diff --git a/tests/test_api_stream_video.py b/tests/test_api_stream_video.py index 4fbedc64..28eea0d3 100644 --- a/tests/test_api_stream_video.py +++ b/tests/test_api_stream_video.py @@ -1,8 +1,10 @@ import ctypes import struct import threading +from concurrent.futures import Future, ThreadPoolExecutor from functools import partial from queue import Queue +from typing import Tuple from unittest.mock import MagicMock, patch import cv2 @@ -10,13 +12,13 @@ import numpy import pytest from facefusion import rtc, rtc_store, state_manager -from facefusion.apis.stream_video import create_video_decoder, create_video_encoder, decode_video_frame, destroy_video_decoder, destroy_video_encoder, encode_video_frame, handle_video_frame, receive_video_frames, run_video_encode_loop, update_video_encoder_bitrate +from facefusion.apis.stream_video import create_video_decoder, create_video_encoder, decode_video_frame, destroy_video_decoder, destroy_video_encoder, encode_video_frame, handle_video_frame, process_video_frame, receive_video_frames, run_video_encode_loop, update_video_encoder_bitrate, update_video_encoder_resolution from facefusion.codecs import aom_encoder, vpx_encoder from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download from facefusion.hash_helper import create_hash from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, vpx as vpx_module -from facefusion.types import FrameHandler, RtcPeer, RtcPeerVideo, VideoCodec, VideoPack +from facefusion.types import FrameHandler, Resolution, RtcPeer, RtcPeerVideo, VideoCodec from facefusion.vision import read_video_frame from .assert_helper import get_test_example_file, get_test_examples_directory @@ -65,16 +67,18 @@ def test_run_video_encode_loop(video_codec : VideoCodec, payload_type : int) -> 'receiver_bitrate': ctypes.c_uint(8000) } - video_queue : Queue[VideoPack] = Queue(maxsize = 30) + video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]] = Queue(maxsize = 30) - video_queue.put((video_frame, 0.1)) + with ThreadPoolExecutor(max_workers = 1) as executor: + video_queue.put((0.1, executor.submit(process_video_frame, video_frame))) - with patch('facefusion.apis.stream_video.rtc.send_video') as send_video_mock: - encode_loop_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_queue), daemon = True) - encode_loop_thread.start() - empty_vision_frame = numpy.empty(0) - video_queue.put((empty_vision_frame, 0.0)) - encode_loop_thread.join(timeout = 5.0) + with patch('facefusion.apis.stream_video.rtc.send_video') as send_video_mock: + encode_loop_thread = threading.Thread(target = run_video_encode_loop, args = (rtc_peer, video_queue), daemon = True) + encode_loop_thread.start() + empty_future : Future[Tuple[bytes, Resolution]] = Future() + empty_future.set_result((bytes(), (0, 0))) + video_queue.put((0.0, empty_future)) + encode_loop_thread.join(timeout = 5.0) assert send_video_mock.called @@ -95,34 +99,37 @@ def test_run_video_encode_loop(video_codec : VideoCodec, payload_type : int) -> @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8', 'vp9' ]) def test_receive_video_frames(video_codec : VideoCodec) -> None: video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) - video_queue : Queue[VideoPack] = Queue(maxsize = 30) + video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]] = Queue(maxsize = 30) datachannel_mock = MagicMock() ready_event = threading.Event() datachannel_mock.rtcSetClosedCallback.side_effect = partial(set_ready_event, ready_event) - with patch('facefusion.libraries.datachannel.create_static_library', return_value = datachannel_mock): - with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): - rtc_peer_video : RtcPeerVideo =\ - { - 'sender_track': 0, - 'receiver_track': 0, - 'codec': video_codec - } - video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer_video, video_queue), daemon = True) - video_receiver_thread.start() - ready_event.wait(timeout = 5.0) - datachannel_mock.rtcSetFrameCallback.call_args[0][1](0, bytes([ 0 ]), 1, None, None) - datachannel_mock.rtcSetClosedCallback.call_args[0][1](0, None) - video_receiver_thread.join(timeout = 5.0) + with ThreadPoolExecutor(max_workers = 1) as executor: + with patch('facefusion.libraries.datachannel.create_static_library', return_value = datachannel_mock): + with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): + with patch('facefusion.apis.stream_video.process_video_frame', return_value = (video_frame.tobytes(), (426, 226))): + rtc_peer_video : RtcPeerVideo =\ + { + 'sender_track': 0, + 'receiver_track': 0, + 'codec': video_codec + } + video_receiver_thread = threading.Thread(target = receive_video_frames, args = (rtc_peer_video, video_queue, executor), daemon = True) + video_receiver_thread.start() + ready_event.wait(timeout = 5.0) + datachannel_mock.rtcSetFrameCallback.call_args[0][1](0, bytes([ 0 ]), 1, None, None) + datachannel_mock.rtcSetClosedCallback.call_args[0][1](0, None) + video_receiver_thread.join(timeout = 5.0) + _, video_future = video_queue.get_nowait() - vision_frame, _ = video_queue.get_nowait() + video_buffer, _ = video_future.result() if is_linux() or is_windows(): - assert create_hash(vision_frame.tobytes()) == 'a17439db' + assert create_hash(video_buffer) == 'a17439db' if is_macos(): - assert create_hash(vision_frame.tobytes()) == '38d00e2a' + assert create_hash(video_buffer) == '38d00e2a' @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8', 'vp9' ]) @@ -197,6 +204,33 @@ def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None: assert vpx_encoder.encode(video_encoder, input_buffer, (426, 226), 1) == bytes() +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8', 'vp9' ]) +def test_update_video_encoder_resolution(video_codec : VideoCodec) -> None: + video_encoder = create_video_encoder(video_codec, (426, 226), 4000) + + if video_codec == 'av1': + assert struct.unpack_from('I', video_encoder, 128 + 12)[0] == 426 + + if video_codec == 'vp8': + assert struct.unpack_from('I', video_encoder, 64 + 12)[0] == 426 + + if video_codec == 'vp9': + assert struct.unpack_from('I', video_encoder, 64 + 12)[0] == 426 + + assert update_video_encoder_resolution(video_codec, video_encoder, (320, 180)) + + if video_codec == 'av1': + assert struct.unpack_from('I', video_encoder, 128 + 12)[0] == 320 + + if video_codec == 'vp8': + assert struct.unpack_from('I', video_encoder, 64 + 12)[0] == 320 + + if video_codec == 'vp9': + assert struct.unpack_from('I', video_encoder, 64 + 12)[0] == 320 + + destroy_video_encoder(video_codec, video_encoder) + + @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8', 'vp9' ]) def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: video_encoder = create_video_encoder(video_codec, (426, 226), 4000) @@ -228,15 +262,18 @@ def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: def test_handle_video_frame(video_codec : VideoCodec) -> None: video_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_decoder = create_video_decoder(video_codec) - video_queue : Queue[VideoPack] = Queue(maxsize = 30) + video_queue : Queue[Tuple[float, Future[Tuple[bytes, Resolution]]]] = Queue(maxsize = 30) - with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): - handle_video_frame(video_codec, video_decoder, video_queue, 0, ctypes.c_void_p(), 1, ctypes.c_void_p(), ctypes.c_void_p()) + with ThreadPoolExecutor(max_workers = 1) as executor: + with patch('facefusion.apis.stream_video.decode_video_frame', return_value = video_frame): + with patch('facefusion.apis.stream_video.process_video_frame', return_value = (video_frame.tobytes(), (426, 226))): + handle_video_frame(video_codec, video_decoder, video_queue, executor, 0, ctypes.c_void_p(), 1, ctypes.c_void_p(), ctypes.c_void_p()) + _, video_future = video_queue.get_nowait() - vision_frame, _ = video_queue.get_nowait() + video_buffer, _ = video_future.result() if is_linux() or is_windows(): - assert create_hash(vision_frame.tobytes()) == 'a17439db' + assert create_hash(video_buffer) == 'a17439db' if is_macos(): - assert create_hash(vision_frame.tobytes()) == '38d00e2a' + assert create_hash(video_buffer) == '38d00e2a'