Combine encode loop methods (#1119)

* combine run_aom_encode_loop and run_vpx_encode_loop to encode_video_loop

* run_opus_encode_loop -> encode_audio_loop

* use else instead of continue

* rename to video_codec
This commit is contained in:
Harisreedhar
2026-05-16 23:29:36 +05:30
committed by GitHub
parent dd1ded1408
commit c48c238f88
6 changed files with 64 additions and 98 deletions
+28 -60
View File
@@ -2,6 +2,7 @@ import asyncio
import queue # TODO: try deque
import time
from collections.abc import AsyncIterator
from functools import partial
from typing import Optional, Tuple, cast, get_args
import cv2
@@ -15,7 +16,7 @@ from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encod
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
from facefusion.streamer import process_vision_frame
from facefusion.types import PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
# TODO: refine this method
@@ -24,10 +25,11 @@ async def handle_video_stream(websocket : WebSocket) -> None:
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
stream_codec : VideoCodec = 'av1'
video_codec : VideoCodec = 'av1'
audio_codec : AudioCodec = 'opus'
if websocket.query_params.get('codec') in get_args(VideoCodec):
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
video_codec = cast(VideoCodec, websocket.query_params.get('codec'))
await websocket.accept(subprotocol = subprotocol)
@@ -51,12 +53,8 @@ async def handle_video_stream(websocket : WebSocket) -> None:
event_loop = asyncio.get_running_loop()
if stream_codec == 'av1':
video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution)
if stream_codec == 'vp8':
video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution)
audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id)
video_encode_task = event_loop.run_in_executor(None, encode_video_loop, video_codec, vision_frame_queue, session_id, resolution)
audio_encode_task = event_loop.run_in_executor(None, encode_audio_loop, audio_codec, audio_chunk_queue, session_id)
await websocket.send_text('ready')
async for frame_type, frame_buffer in stream_frames:
@@ -138,21 +136,29 @@ def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAns
return None
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10)
def encode_video_loop(video_codec : VideoCodec, vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
create_encoder = partial(create_aom_encoder, 4500, 8, 10)
destroy_encoder = destroy_aom_encoder
encode_buffer = encode_aom_buffer
if video_codec == 'vp8':
create_encoder = partial(create_vpx_encoder, 4500, 8, 16)
destroy_encoder = destroy_vpx_encoder # type:ignore[assignment]
encode_buffer = encode_vpx_buffer # type:ignore[assignment]
encoder = create_encoder(frame_resolution)
temp_resolution = frame_resolution
timestamp = 0
vision_frame = vision_frame_queue.get()
while numpy.any(vision_frame) and aom_encoder:
while numpy.any(vision_frame) and encoder:
output_vision_frame = process_vision_frame(vision_frame)
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
if output_resolution == temp_resolution:
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp)
output_frame_buffer = encode_buffer(encoder, output_frame_buffer, output_resolution, timestamp)
rtc_peers = rtc_store.get_peers(session_id)
if output_frame_buffer and rtc_peers:
@@ -161,55 +167,17 @@ def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]],
timestamp += 1
vision_frame = vision_frame_queue.get()
#TODO: we are not using continue as control flow in the project
continue
else:
destroy_encoder(encoder)
temp_resolution = output_resolution
encoder = create_encoder(temp_resolution)
timestamp = 0
destroy_aom_encoder(aom_encoder)
temp_resolution = output_resolution
aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10)
timestamp = 0
if aom_encoder:
destroy_aom_encoder(aom_encoder)
if encoder:
destroy_encoder(encoder)
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16)
temp_resolution = frame_resolution
timestamp = 0
vision_frame = vision_frame_queue.get()
while numpy.any(vision_frame) and vpx_encoder:
output_vision_frame = process_vision_frame(vision_frame)
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
if output_resolution == temp_resolution:
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp)
rtc_peers = rtc_store.get_peers(session_id)
if output_frame_buffer and rtc_peers:
video_timestamp = int(time.monotonic() * 90000)
rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp)
timestamp += 1
vision_frame = vision_frame_queue.get()
# TODO: we are not using continue as control flow in the project
continue
destroy_vpx_encoder(vpx_encoder)
temp_resolution = output_resolution
vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16)
timestamp = 0
if vpx_encoder:
destroy_vpx_encoder(vpx_encoder)
# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards
def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
def encode_audio_loop(audio_codec : AudioCodec, audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
opus_encoder = create_opus_encoder(48000, 2)
audio_timestamp = 0
+1 -1
View File
@@ -6,7 +6,7 @@ from facefusion.libraries import aom as aom_module
from facefusion.types import AomEncoder, BitRate, Resolution
def create_aom_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[AomEncoder]:
def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[AomEncoder]:
aom_library = aom_module.create_static_library()
if aom_library:
+1 -1
View File
@@ -6,7 +6,7 @@ from facefusion.libraries import vpx as vpx_module
from facefusion.types import BitRate, Resolution, VpxEncoder
def create_vpx_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[VpxEncoder]:
def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[VpxEncoder]:
vpx_library = vpx_module.create_static_library()
if vpx_library:
+4 -4
View File
@@ -23,15 +23,15 @@ def before_all() -> None:
def test_create_aom_encoder() -> None:
assert create_aom_encoder((320, 240), 1000, 8, 16)
assert create_aom_encoder((0, 0), 0, 0, 0) is None
assert create_aom_encoder(1000, 8, 16, (320, 240))
assert create_aom_encoder(0, 0, 0, (0, 0)) is None
def test_encode_aom_buffer() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
aom_encoder = create_aom_encoder(video_resolution, 1000, 1, 0)
aom_encoder = create_aom_encoder(1000, 1, 0, video_resolution)
if is_linux() or is_windows():
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31'
@@ -41,7 +41,7 @@ def test_encode_aom_buffer() -> None:
def test_destroy_aom_encoder() -> None:
aom_encoder = create_aom_encoder((320, 240), 1000, 8, 16)
aom_encoder = create_aom_encoder(1000, 8, 16, (320, 240))
with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock:
destroy_aom_encoder(aom_encoder)
+4 -4
View File
@@ -23,15 +23,15 @@ def before_all() -> None:
def test_create_vpx_encoder() -> None:
assert create_vpx_encoder((320, 240), 1000, 8, 16)
assert create_vpx_encoder((0, 0), 0, 0, 0) is None
assert create_vpx_encoder(1000, 8, 16, (320, 240))
assert create_vpx_encoder(0, 0, 0, (0, 0)) is None
def test_encode_vpx_buffer() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
vpx_encoder = create_vpx_encoder(video_resolution, 1000, 1, 0)
vpx_encoder = create_vpx_encoder(1000, 1, 0, video_resolution)
if is_linux() or is_windows():
assert create_hash(encode_vpx_buffer(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f'
@@ -41,7 +41,7 @@ def test_encode_vpx_buffer() -> None:
def test_destroy_vpx_encoder() -> None:
vpx_encoder = create_vpx_encoder((320, 240), 1000, 8, 16)
vpx_encoder = create_vpx_encoder(1000, 8, 16, (320, 240))
with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock:
destroy_vpx_encoder(vpx_encoder)
+26 -28
View File
@@ -9,9 +9,9 @@ import pytest
from numpy.typing import NDArray
from starlette.websockets import WebSocketState
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream
from facefusion.hash_helper import create_hash
from facefusion.types import VisionFrame
from facefusion.types import VideoCodec, VisionFrame
def _make_handler_websocket(events : list[Any]) -> MagicMock:
@@ -35,7 +35,7 @@ def _make_audio_packet(samples : NDArray[Any]) -> bytes:
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_run_video_encode_loop(video_codec : str) -> None:
def test_encode_video_loop(video_codec : VideoCodec) -> None:
frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
@@ -45,13 +45,11 @@ def test_run_video_encode_loop(video_codec : str) -> None:
create_name = prefix + 'create_aom_encoder'
encode_name = prefix + 'encode_aom_buffer'
destroy_name = prefix + 'destroy_aom_encoder'
run_loop = run_aom_encode_loop
if video_codec == 'vp8':
create_name = prefix + 'create_vpx_encoder'
encode_name = prefix + 'encode_vpx_buffer'
destroy_name = prefix + 'destroy_vpx_encoder'
run_loop = run_vp8_encode_loop
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
vision_frame_queue.put(frame)
@@ -62,8 +60,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
patch(destroy_name), \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_called_once()
vision_frame_queue = queue.Queue()
@@ -77,8 +75,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
patch(destroy_name), \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
assert mock_rtc.send_video_to_peers.call_count == 3
vision_frame_queue = queue.Queue()
@@ -86,7 +84,7 @@ def test_run_video_encode_loop(video_codec : str) -> None:
with patch(create_name, return_value = MagicMock()), \
patch(destroy_name) as mock_destroy, \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
mock_destroy.assert_called_once()
@@ -99,7 +97,7 @@ def test_run_video_encode_loop(video_codec : str) -> None:
patch(destroy_name), \
patch(prefix + 'rtc_store'), \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
vision_frame_queue = queue.Queue()
@@ -111,8 +109,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
patch(destroy_name) as mock_destroy, \
patch(prefix + 'rtc_store') as mock_rtc_store, \
patch(prefix + 'rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_loop(vision_frame_queue, 'session-1', (64, 64))
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
assert mock_create.call_count == 2
assert mock_destroy.call_count == 2
mock_rtc.send_video_to_peers.assert_called_once()
@@ -122,12 +120,12 @@ def test_run_video_encode_loop(video_codec : str) -> None:
vision_frame_queue.put(None)
with patch(create_name, return_value = None), \
patch(prefix + 'rtc') as mock_rtc:
run_loop(vision_frame_queue, 'session-1', (64, 64))
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
mock_rtc.send_video_to_peers.assert_not_called()
# TODO: refine test
def test_run_opus_encode_loop() -> None:
def test_encode_audio_loop() -> None:
audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
@@ -138,8 +136,8 @@ def test_run_opus_encode_loop() -> None:
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_called_once()
assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0
@@ -152,8 +150,8 @@ def test_run_opus_encode_loop() -> None:
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
run_opus_encode_loop(audio_chunk_queue, 'session-1')
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
assert mock_rtc.send_audio_to_peers.call_count == 2
assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0
assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960
@@ -166,7 +164,7 @@ def test_run_opus_encode_loop() -> None:
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
patch('facefusion.apis.stream_helper.rtc_store'), \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
run_opus_encode_loop(audio_chunk_queue, 'session-1')
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_not_called()
audio_chunk_queue = queue.Queue()
@@ -177,7 +175,7 @@ def test_run_opus_encode_loop() -> None:
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
patch('facefusion.apis.stream_helper.rtc_store'), \
patch('facefusion.apis.stream_helper.rtc'):
run_opus_encode_loop(audio_chunk_queue, 'session-1')
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
mock_destroy.assert_called_once()
audio_chunk_queue = queue.Queue()
@@ -185,7 +183,7 @@ def test_run_opus_encode_loop() -> None:
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
run_opus_encode_loop(audio_chunk_queue, 'session-1')
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
mock_rtc.send_audio_to_peers.assert_not_called()
mock_destroy.assert_called_once()
@@ -202,8 +200,8 @@ def test_handle_video_stream() -> None:
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \
patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \
patch('facefusion.apis.stream_helper.encode_video_loop') as mock_loop, \
patch('facefusion.apis.stream_helper.encode_audio_loop'), \
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
asyncio.run(handle_video_stream(websocket))
websocket.accept.assert_called_once_with(subprotocol = 'proto')
@@ -211,7 +209,7 @@ def test_handle_video_stream() -> None:
websocket.close.assert_called_once()
mock_rtc.init_peers.assert_called_once_with('session-1')
mock_rtc.delete_peers.assert_called_once_with('session-1')
_, loop_session_id, loop_resolution = mock_loop.call_args[0]
_, _, loop_session_id, loop_resolution = mock_loop.call_args[0]
assert loop_session_id == 'session-1'
assert loop_resolution == (64, 64)
@@ -232,9 +230,9 @@ def test_handle_video_stream() -> None:
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \
patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \
patch('facefusion.apis.stream_helper.encode_video_loop'), \
patch('facefusion.apis.stream_helper.encode_audio_loop') as mock_audio_loop, \
patch('facefusion.apis.stream_helper.rtc_store'):
asyncio.run(handle_video_stream(websocket))
audio_queue = mock_audio_loop.call_args[0][0]
audio_queue = mock_audio_loop.call_args[0][1]
assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'