diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 538046ab..28be711d 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -2,6 +2,7 @@ import asyncio import queue # TODO: try deque import time from collections.abc import AsyncIterator +from functools import partial from typing import Optional, Tuple, cast, get_args import cv2 @@ -15,7 +16,7 @@ from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encod from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer from facefusion.streamer import process_vision_frame -from facefusion.types import PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame +from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame # TODO: refine this method @@ -24,10 +25,11 @@ async def handle_video_stream(websocket : WebSocket) -> None: access_token = extract_access_token(websocket.scope) session_id = session_manager.find_session_id(access_token) session_context.set_session_id(session_id) - stream_codec : VideoCodec = 'av1' + video_codec : VideoCodec = 'av1' + audio_codec : AudioCodec = 'opus' if websocket.query_params.get('codec') in get_args(VideoCodec): - stream_codec = cast(VideoCodec, websocket.query_params.get('codec')) + video_codec = cast(VideoCodec, websocket.query_params.get('codec')) await websocket.accept(subprotocol = subprotocol) @@ -51,12 +53,8 @@ async def handle_video_stream(websocket : WebSocket) -> None: event_loop = asyncio.get_running_loop() - if stream_codec == 'av1': - video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution) - if stream_codec == 'vp8': - video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution) - - audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id) + video_encode_task = event_loop.run_in_executor(None, encode_video_loop, video_codec, vision_frame_queue, session_id, resolution) + audio_encode_task = event_loop.run_in_executor(None, encode_audio_loop, audio_codec, audio_chunk_queue, session_id) await websocket.send_text('ready') async for frame_type, frame_buffer in stream_frames: @@ -138,21 +136,29 @@ def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAns return None -# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards -def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: - aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10) +def encode_video_loop(video_codec : VideoCodec, vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: + create_encoder = partial(create_aom_encoder, 4500, 8, 10) + destroy_encoder = destroy_aom_encoder + encode_buffer = encode_aom_buffer + + if video_codec == 'vp8': + create_encoder = partial(create_vpx_encoder, 4500, 8, 16) + destroy_encoder = destroy_vpx_encoder # type:ignore[assignment] + encode_buffer = encode_vpx_buffer # type:ignore[assignment] + + encoder = create_encoder(frame_resolution) temp_resolution = frame_resolution timestamp = 0 vision_frame = vision_frame_queue.get() - while numpy.any(vision_frame) and aom_encoder: + while numpy.any(vision_frame) and encoder: output_vision_frame = process_vision_frame(vision_frame) output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) if output_resolution == temp_resolution: output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp) + output_frame_buffer = encode_buffer(encoder, output_frame_buffer, output_resolution, timestamp) rtc_peers = rtc_store.get_peers(session_id) if output_frame_buffer and rtc_peers: @@ -161,55 +167,17 @@ def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], timestamp += 1 vision_frame = vision_frame_queue.get() - #TODO: we are not using continue as control flow in the project - continue + else: + destroy_encoder(encoder) + temp_resolution = output_resolution + encoder = create_encoder(temp_resolution) + timestamp = 0 - destroy_aom_encoder(aom_encoder) - temp_resolution = output_resolution - aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10) - timestamp = 0 - - if aom_encoder: - destroy_aom_encoder(aom_encoder) + if encoder: + destroy_encoder(encoder) -# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards -def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None: - vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16) - temp_resolution = frame_resolution - timestamp = 0 - - vision_frame = vision_frame_queue.get() - - while numpy.any(vision_frame) and vpx_encoder: - output_vision_frame = process_vision_frame(vision_frame) - output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0]) - - if output_resolution == temp_resolution: - output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() - output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp) - rtc_peers = rtc_store.get_peers(session_id) - - if output_frame_buffer and rtc_peers: - video_timestamp = int(time.monotonic() * 90000) - rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp) - - timestamp += 1 - vision_frame = vision_frame_queue.get() - # TODO: we are not using continue as control flow in the project - continue - - destroy_vpx_encoder(vpx_encoder) - temp_resolution = output_resolution - vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16) - timestamp = 0 - - if vpx_encoder: - destroy_vpx_encoder(vpx_encoder) - - -# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards -def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None: +def encode_audio_loop(audio_codec : AudioCodec, audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None: opus_encoder = create_opus_encoder(48000, 2) audio_timestamp = 0 diff --git a/facefusion/codecs/aom.py b/facefusion/codecs/aom.py index adbcd74d..ddb42608 100644 --- a/facefusion/codecs/aom.py +++ b/facefusion/codecs/aom.py @@ -6,7 +6,7 @@ from facefusion.libraries import aom as aom_module from facefusion.types import AomEncoder, BitRate, Resolution -def create_aom_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[AomEncoder]: +def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[AomEncoder]: aom_library = aom_module.create_static_library() if aom_library: diff --git a/facefusion/codecs/vpx.py b/facefusion/codecs/vpx.py index ebcacc56..2f76a5bf 100644 --- a/facefusion/codecs/vpx.py +++ b/facefusion/codecs/vpx.py @@ -6,7 +6,7 @@ from facefusion.libraries import vpx as vpx_module from facefusion.types import BitRate, Resolution, VpxEncoder -def create_vpx_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[VpxEncoder]: +def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[VpxEncoder]: vpx_library = vpx_module.create_static_library() if vpx_library: diff --git a/tests/test_codec_aom.py b/tests/test_codec_aom.py index f921a853..918f1a50 100644 --- a/tests/test_codec_aom.py +++ b/tests/test_codec_aom.py @@ -23,15 +23,15 @@ def before_all() -> None: def test_create_aom_encoder() -> None: - assert create_aom_encoder((320, 240), 1000, 8, 16) - assert create_aom_encoder((0, 0), 0, 0, 0) is None + assert create_aom_encoder(1000, 8, 16, (320, 240)) + assert create_aom_encoder(0, 0, 0, (0, 0)) is None def test_encode_aom_buffer() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) - aom_encoder = create_aom_encoder(video_resolution, 1000, 1, 0) + aom_encoder = create_aom_encoder(1000, 1, 0, video_resolution) if is_linux() or is_windows(): assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31' @@ -41,7 +41,7 @@ def test_encode_aom_buffer() -> None: def test_destroy_aom_encoder() -> None: - aom_encoder = create_aom_encoder((320, 240), 1000, 8, 16) + aom_encoder = create_aom_encoder(1000, 8, 16, (320, 240)) with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock: destroy_aom_encoder(aom_encoder) diff --git a/tests/test_codec_vpx.py b/tests/test_codec_vpx.py index 9e297aee..93b25166 100644 --- a/tests/test_codec_vpx.py +++ b/tests/test_codec_vpx.py @@ -23,15 +23,15 @@ def before_all() -> None: def test_create_vpx_encoder() -> None: - assert create_vpx_encoder((320, 240), 1000, 8, 16) - assert create_vpx_encoder((0, 0), 0, 0, 0) is None + assert create_vpx_encoder(1000, 8, 16, (320, 240)) + assert create_vpx_encoder(0, 0, 0, (0, 0)) is None def test_encode_vpx_buffer() -> None: vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes() video_resolution = (vision_frame.shape[1], vision_frame.shape[0]) - vpx_encoder = create_vpx_encoder(video_resolution, 1000, 1, 0) + vpx_encoder = create_vpx_encoder(1000, 1, 0, video_resolution) if is_linux() or is_windows(): assert create_hash(encode_vpx_buffer(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f' @@ -41,7 +41,7 @@ def test_encode_vpx_buffer() -> None: def test_destroy_vpx_encoder() -> None: - vpx_encoder = create_vpx_encoder((320, 240), 1000, 8, 16) + vpx_encoder = create_vpx_encoder(1000, 8, 16, (320, 240)) with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock: destroy_vpx_encoder(vpx_encoder) diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py index 8b08ce3b..c2244db9 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_stream_helper.py @@ -9,9 +9,9 @@ import pytest from numpy.typing import NDArray from starlette.websockets import WebSocketState -from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop +from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream from facefusion.hash_helper import create_hash -from facefusion.types import VisionFrame +from facefusion.types import VideoCodec, VisionFrame def _make_handler_websocket(events : list[Any]) -> MagicMock: @@ -35,7 +35,7 @@ def _make_audio_packet(samples : NDArray[Any]) -> bytes: @pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) -def test_run_video_encode_loop(video_codec : str) -> None: +def test_encode_video_loop(video_codec : VideoCodec) -> None: frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8) large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8) @@ -45,13 +45,11 @@ def test_run_video_encode_loop(video_codec : str) -> None: create_name = prefix + 'create_aom_encoder' encode_name = prefix + 'encode_aom_buffer' destroy_name = prefix + 'destroy_aom_encoder' - run_loop = run_aom_encode_loop if video_codec == 'vp8': create_name = prefix + 'create_vpx_encoder' encode_name = prefix + 'encode_vpx_buffer' destroy_name = prefix + 'destroy_vpx_encoder' - run_loop = run_vp8_encode_loop vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue() vision_frame_queue.put(frame) @@ -62,8 +60,8 @@ def test_run_video_encode_loop(video_codec : str) -> None: patch(destroy_name), \ patch(prefix + 'rtc_store') as mock_rtc_store, \ patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] - run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc_store.get_peers.return_value = [ MagicMock() ] + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) mock_rtc.send_video_to_peers.assert_called_once() vision_frame_queue = queue.Queue() @@ -77,8 +75,8 @@ def test_run_video_encode_loop(video_codec : str) -> None: patch(destroy_name), \ patch(prefix + 'rtc_store') as mock_rtc_store, \ patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] - run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc_store.get_peers.return_value = [ MagicMock() ] + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) assert mock_rtc.send_video_to_peers.call_count == 3 vision_frame_queue = queue.Queue() @@ -86,7 +84,7 @@ def test_run_video_encode_loop(video_codec : str) -> None: with patch(create_name, return_value = MagicMock()), \ patch(destroy_name) as mock_destroy, \ patch(prefix + 'rtc') as mock_rtc: - run_loop(vision_frame_queue, 'session-1', (64, 64)) + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) mock_rtc.send_video_to_peers.assert_not_called() mock_destroy.assert_called_once() @@ -99,7 +97,7 @@ def test_run_video_encode_loop(video_codec : str) -> None: patch(destroy_name), \ patch(prefix + 'rtc_store'), \ patch(prefix + 'rtc') as mock_rtc: - run_loop(vision_frame_queue, 'session-1', (64, 64)) + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) mock_rtc.send_video_to_peers.assert_not_called() vision_frame_queue = queue.Queue() @@ -111,8 +109,8 @@ def test_run_video_encode_loop(video_codec : str) -> None: patch(destroy_name) as mock_destroy, \ patch(prefix + 'rtc_store') as mock_rtc_store, \ patch(prefix + 'rtc') as mock_rtc: - mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] - run_loop(vision_frame_queue, 'session-1', (64, 64)) + mock_rtc_store.get_peers.return_value = [ MagicMock() ] + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) assert mock_create.call_count == 2 assert mock_destroy.call_count == 2 mock_rtc.send_video_to_peers.assert_called_once() @@ -122,12 +120,12 @@ def test_run_video_encode_loop(video_codec : str) -> None: vision_frame_queue.put(None) with patch(create_name, return_value = None), \ patch(prefix + 'rtc') as mock_rtc: - run_loop(vision_frame_queue, 'session-1', (64, 64)) + encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64)) mock_rtc.send_video_to_peers.assert_not_called() # TODO: refine test -def test_run_opus_encode_loop() -> None: +def test_encode_audio_loop() -> None: audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes() audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue() @@ -138,8 +136,8 @@ def test_run_opus_encode_loop() -> None: patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] - run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_rtc_store.get_peers.return_value = [ MagicMock() ] + encode_audio_loop('opus', audio_chunk_queue, 'session-1') mock_rtc.send_audio_to_peers.assert_called_once() assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0 @@ -152,8 +150,8 @@ def test_run_opus_encode_loop() -> None: patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \ patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ] - run_opus_encode_loop(audio_chunk_queue, 'session-1') + mock_rtc_store.get_peers.return_value = [ MagicMock() ] + encode_audio_loop('opus', audio_chunk_queue, 'session-1') assert mock_rtc.send_audio_to_peers.call_count == 2 assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0 assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960 @@ -166,7 +164,7 @@ def test_run_opus_encode_loop() -> None: patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \ patch('facefusion.apis.stream_helper.rtc_store'), \ patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - run_opus_encode_loop(audio_chunk_queue, 'session-1') + encode_audio_loop('opus', audio_chunk_queue, 'session-1') mock_rtc.send_audio_to_peers.assert_not_called() audio_chunk_queue = queue.Queue() @@ -177,7 +175,7 @@ def test_run_opus_encode_loop() -> None: patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ patch('facefusion.apis.stream_helper.rtc_store'), \ patch('facefusion.apis.stream_helper.rtc'): - run_opus_encode_loop(audio_chunk_queue, 'session-1') + encode_audio_loop('opus', audio_chunk_queue, 'session-1') mock_destroy.assert_called_once() audio_chunk_queue = queue.Queue() @@ -185,7 +183,7 @@ def test_run_opus_encode_loop() -> None: with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \ patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \ patch('facefusion.apis.stream_helper.rtc') as mock_rtc: - run_opus_encode_loop(audio_chunk_queue, 'session-1') + encode_audio_loop('opus', audio_chunk_queue, 'session-1') mock_rtc.send_audio_to_peers.assert_not_called() mock_destroy.assert_called_once() @@ -202,8 +200,8 @@ def test_handle_video_stream() -> None: patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ - patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \ - patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \ + patch('facefusion.apis.stream_helper.encode_video_loop') as mock_loop, \ + patch('facefusion.apis.stream_helper.encode_audio_loop'), \ patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc: asyncio.run(handle_video_stream(websocket)) websocket.accept.assert_called_once_with(subprotocol = 'proto') @@ -211,7 +209,7 @@ def test_handle_video_stream() -> None: websocket.close.assert_called_once() mock_rtc.init_peers.assert_called_once_with('session-1') mock_rtc.delete_peers.assert_called_once_with('session-1') - _, loop_session_id, loop_resolution = mock_loop.call_args[0] + _, _, loop_session_id, loop_resolution = mock_loop.call_args[0] assert loop_session_id == 'session-1' assert loop_resolution == (64, 64) @@ -232,9 +230,9 @@ def test_handle_video_stream() -> None: patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \ patch('facefusion.apis.stream_helper.session_context.set_session_id'), \ patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \ - patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \ - patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \ + patch('facefusion.apis.stream_helper.encode_video_loop'), \ + patch('facefusion.apis.stream_helper.encode_audio_loop') as mock_audio_loop, \ patch('facefusion.apis.stream_helper.rtc_store'): asyncio.run(handle_video_stream(websocket)) - audio_queue = mock_audio_loop.call_args[0][0] + audio_queue = mock_audio_loop.call_args[0][1] assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'