mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 19:01:35 +02:00
Combine encode loop methods (#1119)
* combine run_aom_encode_loop and run_vpx_encode_loop to encode_video_loop * run_opus_encode_loop -> encode_audio_loop * use else instead of continue * rename to video_codec
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import queue # TODO: try deque
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, cast, get_args
|
||||
|
||||
import cv2
|
||||
@@ -15,7 +16,7 @@ from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encod
|
||||
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
|
||||
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
from facefusion.types import AudioCodec, PeerConnection, Resolution, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame
|
||||
|
||||
|
||||
# TODO: refine this method
|
||||
@@ -24,10 +25,11 @@ async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
session_context.set_session_id(session_id)
|
||||
stream_codec : VideoCodec = 'av1'
|
||||
video_codec : VideoCodec = 'av1'
|
||||
audio_codec : AudioCodec = 'opus'
|
||||
|
||||
if websocket.query_params.get('codec') in get_args(VideoCodec):
|
||||
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
|
||||
video_codec = cast(VideoCodec, websocket.query_params.get('codec'))
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
@@ -51,12 +53,8 @@ async def handle_video_stream(websocket : WebSocket) -> None:
|
||||
|
||||
event_loop = asyncio.get_running_loop()
|
||||
|
||||
if stream_codec == 'av1':
|
||||
video_encode_task = event_loop.run_in_executor(None, run_aom_encode_loop, vision_frame_queue, session_id, resolution)
|
||||
if stream_codec == 'vp8':
|
||||
video_encode_task = event_loop.run_in_executor(None, run_vp8_encode_loop, vision_frame_queue, session_id, resolution)
|
||||
|
||||
audio_encode_task = event_loop.run_in_executor(None, run_opus_encode_loop, audio_chunk_queue, session_id)
|
||||
video_encode_task = event_loop.run_in_executor(None, encode_video_loop, video_codec, vision_frame_queue, session_id, resolution)
|
||||
audio_encode_task = event_loop.run_in_executor(None, encode_audio_loop, audio_codec, audio_chunk_queue, session_id)
|
||||
await websocket.send_text('ready')
|
||||
|
||||
async for frame_type, frame_buffer in stream_frames:
|
||||
@@ -138,21 +136,29 @@ def connect_rtc(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAns
|
||||
return None
|
||||
|
||||
|
||||
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
|
||||
def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
aom_encoder = create_aom_encoder(frame_resolution, 4500, 8, 10)
|
||||
def encode_video_loop(video_codec : VideoCodec, vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
create_encoder = partial(create_aom_encoder, 4500, 8, 10)
|
||||
destroy_encoder = destroy_aom_encoder
|
||||
encode_buffer = encode_aom_buffer
|
||||
|
||||
if video_codec == 'vp8':
|
||||
create_encoder = partial(create_vpx_encoder, 4500, 8, 16)
|
||||
destroy_encoder = destroy_vpx_encoder # type:ignore[assignment]
|
||||
encode_buffer = encode_vpx_buffer # type:ignore[assignment]
|
||||
|
||||
encoder = create_encoder(frame_resolution)
|
||||
temp_resolution = frame_resolution
|
||||
timestamp = 0
|
||||
|
||||
vision_frame = vision_frame_queue.get()
|
||||
|
||||
while numpy.any(vision_frame) and aom_encoder:
|
||||
while numpy.any(vision_frame) and encoder:
|
||||
output_vision_frame = process_vision_frame(vision_frame)
|
||||
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
|
||||
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_aom_buffer(aom_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
output_frame_buffer = encode_buffer(encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if output_frame_buffer and rtc_peers:
|
||||
@@ -161,55 +167,17 @@ def run_aom_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]],
|
||||
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
#TODO: we are not using continue as control flow in the project
|
||||
continue
|
||||
else:
|
||||
destroy_encoder(encoder)
|
||||
temp_resolution = output_resolution
|
||||
encoder = create_encoder(temp_resolution)
|
||||
timestamp = 0
|
||||
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
temp_resolution = output_resolution
|
||||
aom_encoder = create_aom_encoder(temp_resolution, 4500, 8, 10)
|
||||
timestamp = 0
|
||||
|
||||
if aom_encoder:
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
if encoder:
|
||||
destroy_encoder(encoder)
|
||||
|
||||
|
||||
# TODO: switch to loop_encode_video or encode_video_loop ... pass video_codec to follow standards
|
||||
def run_vp8_encode_loop(vision_frame_queue : queue.Queue[Optional[VisionFrame]], session_id : SessionId, frame_resolution : Resolution) -> None:
|
||||
vpx_encoder = create_vpx_encoder(frame_resolution, 4500, 8, 16)
|
||||
temp_resolution = frame_resolution
|
||||
timestamp = 0
|
||||
|
||||
vision_frame = vision_frame_queue.get()
|
||||
|
||||
while numpy.any(vision_frame) and vpx_encoder:
|
||||
output_vision_frame = process_vision_frame(vision_frame)
|
||||
output_resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
|
||||
|
||||
if output_resolution == temp_resolution:
|
||||
output_frame_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
output_frame_buffer = encode_vpx_buffer(vpx_encoder, output_frame_buffer, output_resolution, timestamp)
|
||||
rtc_peers = rtc_store.get_peers(session_id)
|
||||
|
||||
if output_frame_buffer and rtc_peers:
|
||||
video_timestamp = int(time.monotonic() * 90000)
|
||||
rtc.send_video_to_peers(rtc_peers, output_frame_buffer, video_timestamp)
|
||||
|
||||
timestamp += 1
|
||||
vision_frame = vision_frame_queue.get()
|
||||
# TODO: we are not using continue as control flow in the project
|
||||
continue
|
||||
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
temp_resolution = output_resolution
|
||||
vpx_encoder = create_vpx_encoder(temp_resolution, 4500, 8, 16)
|
||||
timestamp = 0
|
||||
|
||||
if vpx_encoder:
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
|
||||
|
||||
# TODO: switch to loop_encode_audio or encode_audio_loop ... pass audio_codec to follow standards
|
||||
def run_opus_encode_loop(audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
|
||||
def encode_audio_loop(audio_codec : AudioCodec, audio_chunk_queue : queue.Queue[Optional[bytes]], session_id : SessionId) -> None:
|
||||
opus_encoder = create_opus_encoder(48000, 2)
|
||||
audio_timestamp = 0
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from facefusion.libraries import aom as aom_module
|
||||
from facefusion.types import AomEncoder, BitRate, Resolution
|
||||
|
||||
|
||||
def create_aom_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[AomEncoder]:
|
||||
def create_aom_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[AomEncoder]:
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
if aom_library:
|
||||
|
||||
@@ -6,7 +6,7 @@ from facefusion.libraries import vpx as vpx_module
|
||||
from facefusion.types import BitRate, Resolution, VpxEncoder
|
||||
|
||||
|
||||
def create_vpx_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, cpu_count : int) -> Optional[VpxEncoder]:
|
||||
def create_vpx_encoder(bitrate: BitRate, thread_count: int, cpu_count: int, frame_resolution: Resolution) -> Optional[VpxEncoder]:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
if vpx_library:
|
||||
|
||||
@@ -23,15 +23,15 @@ def before_all() -> None:
|
||||
|
||||
|
||||
def test_create_aom_encoder() -> None:
|
||||
assert create_aom_encoder((320, 240), 1000, 8, 16)
|
||||
assert create_aom_encoder((0, 0), 0, 0, 0) is None
|
||||
assert create_aom_encoder(1000, 8, 16, (320, 240))
|
||||
assert create_aom_encoder(0, 0, 0, (0, 0)) is None
|
||||
|
||||
|
||||
def test_encode_aom_buffer() -> None:
|
||||
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
|
||||
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
aom_encoder = create_aom_encoder(video_resolution, 1000, 1, 0)
|
||||
aom_encoder = create_aom_encoder(1000, 1, 0, video_resolution)
|
||||
|
||||
if is_linux() or is_windows():
|
||||
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31'
|
||||
@@ -41,7 +41,7 @@ def test_encode_aom_buffer() -> None:
|
||||
|
||||
|
||||
def test_destroy_aom_encoder() -> None:
|
||||
aom_encoder = create_aom_encoder((320, 240), 1000, 8, 16)
|
||||
aom_encoder = create_aom_encoder(1000, 8, 16, (320, 240))
|
||||
|
||||
with patch.object(aom_module.create_static_library(), 'aom_codec_destroy') as mock:
|
||||
destroy_aom_encoder(aom_encoder)
|
||||
|
||||
@@ -23,15 +23,15 @@ def before_all() -> None:
|
||||
|
||||
|
||||
def test_create_vpx_encoder() -> None:
|
||||
assert create_vpx_encoder((320, 240), 1000, 8, 16)
|
||||
assert create_vpx_encoder((0, 0), 0, 0, 0) is None
|
||||
assert create_vpx_encoder(1000, 8, 16, (320, 240))
|
||||
assert create_vpx_encoder(0, 0, 0, (0, 0)) is None
|
||||
|
||||
|
||||
def test_encode_vpx_buffer() -> None:
|
||||
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
|
||||
video_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
|
||||
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
vpx_encoder = create_vpx_encoder(video_resolution, 1000, 1, 0)
|
||||
vpx_encoder = create_vpx_encoder(1000, 1, 0, video_resolution)
|
||||
|
||||
if is_linux() or is_windows():
|
||||
assert create_hash(encode_vpx_buffer(vpx_encoder, video_buffer, video_resolution, 3)) == 'ce133a1f'
|
||||
@@ -41,7 +41,7 @@ def test_encode_vpx_buffer() -> None:
|
||||
|
||||
|
||||
def test_destroy_vpx_encoder() -> None:
|
||||
vpx_encoder = create_vpx_encoder((320, 240), 1000, 8, 16)
|
||||
vpx_encoder = create_vpx_encoder(1000, 8, 16, (320, 240))
|
||||
|
||||
with patch.object(vpx_module.create_static_library(), 'vpx_codec_destroy') as mock:
|
||||
destroy_vpx_encoder(vpx_encoder)
|
||||
|
||||
+26
-28
@@ -9,9 +9,9 @@ import pytest
|
||||
from numpy.typing import NDArray
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from facefusion.apis.stream_helper import handle_video_stream, run_aom_encode_loop, run_opus_encode_loop, run_vp8_encode_loop
|
||||
from facefusion.apis.stream_helper import encode_audio_loop, encode_video_loop, handle_video_stream
|
||||
from facefusion.hash_helper import create_hash
|
||||
from facefusion.types import VisionFrame
|
||||
from facefusion.types import VideoCodec, VisionFrame
|
||||
|
||||
|
||||
def _make_handler_websocket(events : list[Any]) -> MagicMock:
|
||||
@@ -35,7 +35,7 @@ def _make_audio_packet(samples : NDArray[Any]) -> bytes:
|
||||
|
||||
|
||||
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
|
||||
def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
def test_encode_video_loop(video_codec : VideoCodec) -> None:
|
||||
frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
|
||||
small_frame = numpy.full((64, 64, 3), 128, dtype = numpy.uint8)
|
||||
large_frame = numpy.full((128, 128, 3), 128, dtype = numpy.uint8)
|
||||
@@ -45,13 +45,11 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
create_name = prefix + 'create_aom_encoder'
|
||||
encode_name = prefix + 'encode_aom_buffer'
|
||||
destroy_name = prefix + 'destroy_aom_encoder'
|
||||
run_loop = run_aom_encode_loop
|
||||
|
||||
if video_codec == 'vp8':
|
||||
create_name = prefix + 'create_vpx_encoder'
|
||||
encode_name = prefix + 'encode_vpx_buffer'
|
||||
destroy_name = prefix + 'destroy_vpx_encoder'
|
||||
run_loop = run_vp8_encode_loop
|
||||
|
||||
vision_frame_queue : queue.Queue[Optional[VisionFrame]] = queue.Queue()
|
||||
vision_frame_queue.put(frame)
|
||||
@@ -62,8 +60,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
patch(destroy_name), \
|
||||
patch(prefix + 'rtc_store') as mock_rtc_store, \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc.send_video_to_peers.assert_called_once()
|
||||
|
||||
vision_frame_queue = queue.Queue()
|
||||
@@ -77,8 +75,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
patch(destroy_name), \
|
||||
patch(prefix + 'rtc_store') as mock_rtc_store, \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
assert mock_rtc.send_video_to_peers.call_count == 3
|
||||
|
||||
vision_frame_queue = queue.Queue()
|
||||
@@ -86,7 +84,7 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
with patch(create_name, return_value = MagicMock()), \
|
||||
patch(destroy_name) as mock_destroy, \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc.send_video_to_peers.assert_not_called()
|
||||
mock_destroy.assert_called_once()
|
||||
|
||||
@@ -99,7 +97,7 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
patch(destroy_name), \
|
||||
patch(prefix + 'rtc_store'), \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc.send_video_to_peers.assert_not_called()
|
||||
|
||||
vision_frame_queue = queue.Queue()
|
||||
@@ -111,8 +109,8 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
patch(destroy_name) as mock_destroy, \
|
||||
patch(prefix + 'rtc_store') as mock_rtc_store, \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
assert mock_create.call_count == 2
|
||||
assert mock_destroy.call_count == 2
|
||||
mock_rtc.send_video_to_peers.assert_called_once()
|
||||
@@ -122,12 +120,12 @@ def test_run_video_encode_loop(video_codec : str) -> None:
|
||||
vision_frame_queue.put(None)
|
||||
with patch(create_name, return_value = None), \
|
||||
patch(prefix + 'rtc') as mock_rtc:
|
||||
run_loop(vision_frame_queue, 'session-1', (64, 64))
|
||||
encode_video_loop(video_codec, vision_frame_queue, 'session-1', (64, 64))
|
||||
mock_rtc.send_video_to_peers.assert_not_called()
|
||||
|
||||
|
||||
# TODO: refine test
|
||||
def test_run_opus_encode_loop() -> None:
|
||||
def test_encode_audio_loop() -> None:
|
||||
audio_chunk = numpy.zeros(1920, dtype = numpy.float32).tobytes()
|
||||
|
||||
audio_chunk_queue : queue.Queue[Optional[bytes]] = queue.Queue()
|
||||
@@ -138,8 +136,8 @@ def test_run_opus_encode_loop() -> None:
|
||||
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
|
||||
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
|
||||
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
|
||||
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
|
||||
run_opus_encode_loop(audio_chunk_queue, 'session-1')
|
||||
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
|
||||
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
|
||||
mock_rtc.send_audio_to_peers.assert_called_once()
|
||||
assert mock_rtc.send_audio_to_peers.call_args[0][2] == 0
|
||||
|
||||
@@ -152,8 +150,8 @@ def test_run_opus_encode_loop() -> None:
|
||||
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
|
||||
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc_store, \
|
||||
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
|
||||
mock_rtc_store.get_rtc_peers.return_value = [ MagicMock() ]
|
||||
run_opus_encode_loop(audio_chunk_queue, 'session-1')
|
||||
mock_rtc_store.get_peers.return_value = [ MagicMock() ]
|
||||
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
|
||||
assert mock_rtc.send_audio_to_peers.call_count == 2
|
||||
assert mock_rtc.send_audio_to_peers.call_args_list[0][0][2] == 0
|
||||
assert mock_rtc.send_audio_to_peers.call_args_list[1][0][2] == 960
|
||||
@@ -166,7 +164,7 @@ def test_run_opus_encode_loop() -> None:
|
||||
patch('facefusion.apis.stream_helper.destroy_opus_encoder'), \
|
||||
patch('facefusion.apis.stream_helper.rtc_store'), \
|
||||
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
|
||||
run_opus_encode_loop(audio_chunk_queue, 'session-1')
|
||||
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
|
||||
mock_rtc.send_audio_to_peers.assert_not_called()
|
||||
|
||||
audio_chunk_queue = queue.Queue()
|
||||
@@ -177,7 +175,7 @@ def test_run_opus_encode_loop() -> None:
|
||||
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
|
||||
patch('facefusion.apis.stream_helper.rtc_store'), \
|
||||
patch('facefusion.apis.stream_helper.rtc'):
|
||||
run_opus_encode_loop(audio_chunk_queue, 'session-1')
|
||||
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
|
||||
mock_destroy.assert_called_once()
|
||||
|
||||
audio_chunk_queue = queue.Queue()
|
||||
@@ -185,7 +183,7 @@ def test_run_opus_encode_loop() -> None:
|
||||
with patch('facefusion.apis.stream_helper.create_opus_encoder', return_value = MagicMock()), \
|
||||
patch('facefusion.apis.stream_helper.destroy_opus_encoder') as mock_destroy, \
|
||||
patch('facefusion.apis.stream_helper.rtc') as mock_rtc:
|
||||
run_opus_encode_loop(audio_chunk_queue, 'session-1')
|
||||
encode_audio_loop('opus', audio_chunk_queue, 'session-1')
|
||||
mock_rtc.send_audio_to_peers.assert_not_called()
|
||||
mock_destroy.assert_called_once()
|
||||
|
||||
@@ -202,8 +200,8 @@ def test_handle_video_stream() -> None:
|
||||
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
|
||||
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
|
||||
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
|
||||
patch('facefusion.apis.stream_helper.run_aom_encode_loop') as mock_loop, \
|
||||
patch('facefusion.apis.stream_helper.run_opus_encode_loop'), \
|
||||
patch('facefusion.apis.stream_helper.encode_video_loop') as mock_loop, \
|
||||
patch('facefusion.apis.stream_helper.encode_audio_loop'), \
|
||||
patch('facefusion.apis.stream_helper.rtc_store') as mock_rtc:
|
||||
asyncio.run(handle_video_stream(websocket))
|
||||
websocket.accept.assert_called_once_with(subprotocol = 'proto')
|
||||
@@ -211,7 +209,7 @@ def test_handle_video_stream() -> None:
|
||||
websocket.close.assert_called_once()
|
||||
mock_rtc.init_peers.assert_called_once_with('session-1')
|
||||
mock_rtc.delete_peers.assert_called_once_with('session-1')
|
||||
_, loop_session_id, loop_resolution = mock_loop.call_args[0]
|
||||
_, _, loop_session_id, loop_resolution = mock_loop.call_args[0]
|
||||
assert loop_session_id == 'session-1'
|
||||
assert loop_resolution == (64, 64)
|
||||
|
||||
@@ -232,9 +230,9 @@ def test_handle_video_stream() -> None:
|
||||
patch('facefusion.apis.stream_helper.session_manager.find_session_id', return_value = 'session-1'), \
|
||||
patch('facefusion.apis.stream_helper.session_context.set_session_id'), \
|
||||
patch('facefusion.apis.stream_helper.state_manager.get_item', return_value = 30), \
|
||||
patch('facefusion.apis.stream_helper.run_aom_encode_loop'), \
|
||||
patch('facefusion.apis.stream_helper.run_opus_encode_loop') as mock_audio_loop, \
|
||||
patch('facefusion.apis.stream_helper.encode_video_loop'), \
|
||||
patch('facefusion.apis.stream_helper.encode_audio_loop') as mock_audio_loop, \
|
||||
patch('facefusion.apis.stream_helper.rtc_store'):
|
||||
asyncio.run(handle_video_stream(websocket))
|
||||
audio_queue = mock_audio_loop.call_args[0][0]
|
||||
audio_queue = mock_audio_loop.call_args[0][1]
|
||||
assert create_hash(audio_queue.get_nowait()) == '6d72f0fc'
|
||||
|
||||
Reference in New Issue
Block a user