Refine stream helper: queue-based loops, AV1/WHIP support, endpoint separation (#1124)

* rearrange methods

* add test_stream_helper.py

* improve tests

* use deque

* move decoder to receive methods

* remove cleanup_peer

* add destroy_stream

* make run_peer_loop more readable

* make video and audio methods similar

* change deque to queue to avoid extra thread event

* remove negative condition

* cleanup

* remove wait_for_frame

* cleanup

* cleanup

* fix process_image

* fix lint

* cleanup

* remove last_time

* add todos
This commit is contained in:
Harisreedhar
2026-05-20 23:43:14 +05:30
committed by GitHub
parent 48869bedf0
commit 520dcbfd6b
3 changed files with 509 additions and 229 deletions
+13 -2
View File
@@ -1,15 +1,26 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_404_NOT_FOUND
from starlette.websockets import WebSocket
from starlette.websockets import WebSocket, WebSocketState
from facefusion import session_context, session_manager
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion.apis.stream_helper import destroy_stream, process_image, process_video
# TODO: can we avoid passing websocket? just the data if doable
async def websocket_stream(websocket : WebSocket) -> None:
return await process_image(websocket)
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
await websocket.accept(subprotocol = subprotocol)
await process_image(websocket)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
async def post_stream(request : Request) -> Response:
+185 -227
View File
@@ -1,33 +1,26 @@
import asyncio
import contextlib
import ctypes
import queue
import threading
import time
from collections.abc import AsyncIterator
from typing import Optional
import cv2
import numpy
from starlette.websockets import WebSocket, WebSocketState
from starlette.websockets import WebSocket
from facefusion import rtc, rtc_store, session_context, session_manager, state_manager, streamer
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion import rtc, rtc_store, state_manager, streamer
from facefusion.audio import create_empty_audio_frame
from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.libraries import datachannel as datachannel_module
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, OpusDecoder, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
#TODO: needs review
async def process_image(websocket : WebSocket) -> None:
#TODO: all the websocket handling belongs to the endpoint, these are connection concerns
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
await websocket.accept(subprotocol = subprotocol)
if source_paths:
capture_vision_frame = await anext(receive_vision_frames(websocket), None)
@@ -38,8 +31,63 @@ async def process_image(websocket : WebSocket) -> None:
if is_success:
await websocket.send_bytes(output_frame_buffer.tobytes())
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
#TODO: needs review
def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
    """
    Answer a WebRTC offer and start the peer loop for the session.

    AV1 is selected when the offer carries an AV1 payload type, VP8 otherwise.
    Opus audio tracks are only added when the offer carries an Opus payload
    type. Returns the SDP answer, or None when the offer has no payload type
    for the selected video codec or when no answer could be created.
    """
    video_codec : VideoCodec = 'av1' if rtc.get_payload_type(sdp_offer, 'av1') else 'vp8'
    video_payload_type = rtc.get_payload_type(sdp_offer, video_codec)

    if not video_payload_type:
        return None

    peer_connection : PeerConnection = rtc.create_peer_connection()
    video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', video_codec, video_payload_type)
    video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', video_codec, video_payload_type)
    audio_codec : AudioCodec = 'opus'
    audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec)
    audio_receiver_track = None
    audio_sender_track = None

    if audio_payload_type:
        audio_receiver_track = rtc.add_audio_track(peer_connection, 'recvonly', audio_codec, audio_payload_type)
        audio_sender_track = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, audio_payload_type)

    rtc.set_remote_description(peer_connection, sdp_offer)
    local_sdp = rtc.create_sdp_answer(peer_connection)

    # NOTE(review): the peer connection is not released on this path - confirm whether it should be
    if not local_sdp:
        return None

    rtc_peer : RtcPeer =\
    {
        'peer_connection': peer_connection,
        'video':
        {
            'sender_track': video_sender_track,
            'receiver_track': video_receiver_track,
            'codec': video_codec
        }
    }

    if audio_receiver_track and audio_sender_track:
        rtc_peer['audio'] =\
        {
            'sender_track': audio_sender_track,
            'receiver_track': audio_receiver_track,
            'codec': audio_codec
        }

    rtc_store.init_peers(session_id)
    rtc_store.get_peers(session_id).append(rtc_peer)
    # the peer loop blocks, so it is handed to the default executor and not awaited
    asyncio.get_event_loop().run_in_executor(None, run_peer_loop, session_id, rtc_peer)
    return local_sdp
#TODO: needs review
@@ -56,164 +104,115 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr
websocket_event = await websocket.receive()
#TODO: this only exists because the endpoint stream.py is not allowed to access the rtc store directly
def destroy_stream(session_id : SessionId) -> bool:
    """Delete the peers of the session and report whether none are left afterwards."""
    rtc_store.delete_peers(session_id)
    remaining_peers = rtc_store.get_peers(session_id)
    return not remaining_peers
#TODO: needs review
def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
    """
    Negotiate video (AV1 preferred, VP8 fallback) and optional Opus audio for a session.

    Returns None when the offer carries no payload type for the selected video
    codec, otherwise whatever rtc.create_sdp_answer() produced. The peer is only
    stored and its loop only started when that answer is truthy.
    """
    offers_av1 = rtc.get_payload_type(sdp_offer, 'av1')
    video_codec : VideoCodec = 'av1' if offers_av1 else 'vp8'
    video_payload_type = rtc.get_payload_type(sdp_offer, video_codec)

    if not video_payload_type:
        return None

    peer_connection : PeerConnection = rtc.create_peer_connection()
    video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', video_codec, video_payload_type)
    video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', video_codec, video_payload_type)

    audio_codec : AudioCodec = 'opus'
    audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec)
    audio_receiver_track = None
    audio_sender_track = None

    if audio_payload_type:
        audio_receiver_track = rtc.add_audio_track(peer_connection, 'recvonly', audio_codec, audio_payload_type)
        audio_sender_track = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, audio_payload_type)

    rtc.set_remote_description(peer_connection, sdp_offer)
    local_sdp = rtc.create_sdp_answer(peer_connection)

    if local_sdp:
        video_info =\
        {
            'sender_track': video_sender_track,
            'receiver_track': video_receiver_track,
            'codec': video_codec
        }
        rtc_peer : RtcPeer =\
        {
            'peer_connection': peer_connection,
            'video': video_info
        }

        if audio_receiver_track and audio_sender_track:
            rtc_peer['audio'] =\
            {
                'sender_track': audio_sender_track,
                'receiver_track': audio_receiver_track,
                'codec': audio_codec
            }

        rtc_store.init_peers(session_id)
        rtc_store.get_peers(session_id).append(rtc_peer)
        # the peer loop blocks, so it runs in the default executor and is not awaited
        asyncio.get_event_loop().run_in_executor(None, run_peer_loop, session_id, rtc_peer)

    return local_sdp
#TODO: needs review
def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
datachannel_library = datachannel_module.create_static_library()
video_info = rtc_peer.get('video')
video_codec = video_info.get('codec')
video_decoder = create_video_decoder(video_codec)
audio_info = rtc_peer.get('audio')
audio_decoder = opus_decoder.create(48000, 2) if audio_info else None
video_receive_buffer = ctypes.create_string_buffer(512 * 1024)
audio_receive_buffer = ctypes.create_string_buffer(8 * 1024)
video_codec = rtc_peer.get('video').get('codec')
video_track = rtc_peer.get('video').get('receiver_track')
stop_event = threading.Event()
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
receiver_threads = [threading.Thread(target = receive_video_frames, args = (video_track, video_codec, video_queue, stop_event), daemon = True)]
frame_buffer = poll_for_buffer(datachannel_library, video_info.get('receiver_track'), video_receive_buffer, 30.0)
if rtc_peer.get('audio'):
audio_codec = rtc_peer.get('audio').get('codec')
audio_track = rtc_peer.get('audio').get('receiver_track')
receiver_threads.append(threading.Thread(target = receive_audio_frames, args = (audio_track, audio_codec, audio_queue, stop_event), daemon = True))
if frame_buffer is None:
cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder)
return
for receiver_thread in receiver_threads:
receiver_thread.start()
vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer)
temp_vision_frame = video_queue.get()
if vision_frame is None:
cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder)
return
if numpy.any(temp_vision_frame):
audio_frame = create_empty_audio_frame()
temp_resolution : Resolution = (temp_vision_frame.shape[1], temp_vision_frame.shape[0])
video_encoder = create_video_encoder(video_codec, temp_resolution)
audio_encoder = opus_encoder.create(48000, 2)
frame_index = 0
audio_frame = create_empty_audio_frame()
resolution : Resolution = (vision_frame.shape[1], vision_frame.shape[0])
video_encoder = create_video_encoder(video_codec, resolution)
audio_encoder = opus_encoder.create(48000, 2)
frame_index = 0
while numpy.any(temp_vision_frame):
with contextlib.suppress(queue.Empty):
audio_frame = audio_queue.get_nowait()
while True:
if audio_info and audio_decoder:
audio_frame = receive_audio_frame(datachannel_library, audio_info.get('receiver_track'), audio_decoder, audio_receive_buffer)
output_vision_frame = streamer.process_frame(audio_frame, temp_vision_frame)
output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
output_vision_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
output_vision_frame = streamer.process_frame(audio_frame, vision_frame)
output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
if output_resolution == temp_resolution:
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
else:
destroy_video_encoder(video_codec, video_encoder)
temp_resolution = output_resolution
video_encoder = create_video_encoder(video_codec, temp_resolution)
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
if output_resolution != resolution:
resolution = output_resolution
destroy_video_encoder(video_codec, video_encoder)
video_encoder = create_video_encoder(video_codec, resolution)
send_timestamp = time.monotonic()
raw_vision_frame = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420)
if output_video_buffer:
rtc.send_video(rtc_peer, output_video_buffer, int(send_timestamp * 90000))
if video_codec == 'av1':
encoded_video_buffer = aom_encoder.encode(video_encoder, raw_vision_frame.tobytes(), resolution, frame_index)
if video_codec == 'vp8':
encoded_video_buffer = vpx_encoder.encode(video_encoder, raw_vision_frame.tobytes(), resolution, frame_index)
if audio_encoder and audio_frame.dtype == numpy.float32:
output_audio_buffer = opus_encoder.encode(audio_encoder, audio_frame.tobytes(), 960)
if encoded_video_buffer:
video_timestamp = int(time.monotonic() * 90000)
rtc.send_video(rtc_peer, encoded_video_buffer, video_timestamp)
if output_audio_buffer:
rtc.send_audio(rtc_peer, output_audio_buffer, int(send_timestamp * 48000))
if audio_encoder and audio_frame is not None and audio_frame.size > 0:
encoded_audio_buffer = opus_encoder.encode(audio_encoder, audio_frame.tobytes(), 960)
frame_index += 1
temp_vision_frame = video_queue.get()
if encoded_audio_buffer:
audio_timestamp = int(time.monotonic() * 48000)
rtc.send_audio(rtc_peer, encoded_audio_buffer, audio_timestamp)
destroy_video_encoder(video_codec, video_encoder)
opus_encoder.destroy(audio_encoder)
frame_index += 1
stop_event.set()
next_frame = drain_to_latest_frame(datachannel_library, video_info.get('receiver_track'), video_codec, video_decoder, video_receive_buffer)
if next_frame is not None:
vision_frame = next_frame
continue
next_frame = poll_for_frame(datachannel_library, video_info.get('receiver_track'), video_codec, video_decoder, video_receive_buffer, 30.0)
if next_frame is None:
break
vision_frame = next_frame
destroy_video_encoder(video_codec, video_encoder)
opus_encoder.destroy(audio_encoder)
cleanup_peer(session_id, rtc_peer, video_codec, video_decoder, audio_decoder)
#TODO: needs review
def cleanup_peer(session_id : SessionId, rtc_peer : RtcPeer, video_codec : VideoCodec, video_decoder : Optional[VpxDecoder | AomDecoder], audio_decoder : Optional[OpusDecoder]) -> None:
if video_decoder:
if video_codec == 'av1':
aom_decoder.destroy(video_decoder)
if video_codec == 'vp8':
vpx_decoder.destroy(video_decoder)
if audio_decoder:
opus_decoder.destroy(audio_decoder)
for receiver_thread in receiver_threads:
receiver_thread.join()
rtc_store.delete_peers(session_id)
def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queue : queue.Queue[VisionFrame], stop_event : threading.Event) -> None:
    """
    Receive and decode video from one track until stop_event is set.

    Only the most recent decoded frame is kept in video_queue: a frame that is
    still queued gets evicted before the new one is put. Once stopped, an empty
    array is queued as the end marker and the decoder is destroyed. The decoder
    is destroyed even when receiving or decoding raises.
    """
    datachannel_library = datachannel_module.create_static_library()
    video_decoder = create_video_decoder(video_codec)
    receive_buffer = ctypes.create_string_buffer(512 * 1024)

    try:
        while not stop_event.is_set(): # TODO: use positive while condition
            frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer)

            if frame_buffer:
                vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer)

                if numpy.any(vision_frame):
                    # evict the stale frame so the consumer never processes outdated video
                    with contextlib.suppress(queue.Empty):
                        video_queue.get_nowait()
                    video_queue.put_nowait(vision_frame)
            else:
                stop_event.wait(timeout = 0.001) # TODO: remove this timeout

        # numpy.any() of an empty array is False, which ends a consumer looping on numpy.any(frame)
        video_queue.put(numpy.empty(0))
    finally:
        # release the native decoder on every exit path, including exceptions
        if video_codec == 'av1':
            aom_decoder.destroy(video_decoder)
        if video_codec == 'vp8':
            vpx_decoder.destroy(video_decoder)
def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queue : queue.Queue[AudioFrame], stop_event : threading.Event) -> None:
    """
    Receive and decode Opus audio from one track until stop_event is set.

    Frames with dtype float32 are queued in arrival order; any other frame is
    treated as "nothing received" and followed by a short wait. Queuing waits
    while the queue is full but gives up as soon as stop_event is set, so a
    stalled consumer cannot keep this thread alive after a stop request. The
    decoder is destroyed on every exit path.
    """
    datachannel_library = datachannel_module.create_static_library()
    audio_decoder = opus_decoder.create(48000, 2)
    receive_buffer = ctypes.create_string_buffer(8 * 1024)

    try:
        while not stop_event.is_set(): # TODO: use positive while condition
            audio_frame = receive_audio_frame(datachannel_library, audio_track, audio_decoder, receive_buffer)

            if audio_frame.dtype == numpy.float32:
                # bounded put: re-check stop_event periodically instead of blocking forever on a full queue
                while not stop_event.is_set():
                    try:
                        audio_queue.put(audio_frame, timeout = 0.1)
                        break
                    except queue.Full:
                        continue
            else:
                stop_event.wait(timeout = 0.001) # TODO: remove this timeout
    finally:
        # release the native decoder even when receiving or decoding raises
        opus_decoder.destroy(audio_decoder)
#TODO: needs review
def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomDecoder]:
if video_codec == 'av1':
@@ -242,21 +241,12 @@ def destroy_video_encoder(video_codec : VideoCodec, video_encoder : Optional[Vpx
vpx_encoder.destroy(video_encoder)
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[VisionFrame]:
if video_codec == 'av1':
aom_pointer = aom_decoder.decode(video_decoder, frame_buffer)
if aom_pointer:
frame_width, frame_height = aom_pointer.get('resolution')
yuv_frame = numpy.frombuffer(aom_pointer.get('buffer'), dtype = numpy.uint8).reshape((frame_height * 3 // 2, frame_width))
return cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
if video_codec == 'vp8':
vpx_pointer = vpx_decoder.decode(video_decoder, frame_buffer)
if vpx_pointer:
frame_width, frame_height = vpx_pointer.get('resolution')
yuv_frame = numpy.frombuffer(vpx_pointer.get('buffer'), dtype = numpy.uint8).reshape((frame_height * 3 // 2, frame_width))
return cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
def destroy_stream(session_id : SessionId) -> bool:
if rtc_store.get_peers(session_id):
rtc_store.delete_peers(session_id)
return True
return None
return False
#TODO: needs review
@@ -274,6 +264,36 @@ def receive_audio_frame(datachannel_library : ctypes.CDLL, audio_track : int, au
return create_empty_audio_frame()
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[VisionFrame]:
    """
    Decode one encoded frame and convert it from I420 to BGR.

    Returns None when the codec is neither 'av1' nor 'vp8' or when the decoder
    yields nothing for the buffer.
    """
    decoded_pointer = None

    if video_codec == 'av1':
        decoded_pointer = aom_decoder.decode(video_decoder, frame_buffer)
    elif video_codec == 'vp8':
        decoded_pointer = vpx_decoder.decode(video_decoder, frame_buffer)

    if not decoded_pointer:
        return None

    frame_width, frame_height = decoded_pointer.get('resolution')
    # I420 layout: a full-height luma plane followed by half-height chroma, hence height * 3 // 2 rows
    yuv_shape = (frame_height * 3 // 2, frame_width)
    yuv_frame = numpy.frombuffer(decoded_pointer.get('buffer'), dtype = numpy.uint8).reshape(yuv_shape)
    return cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
def encode_video_frame(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, raw_frame_bytes : bytes, resolution : Resolution, frame_index : int) -> bytes:
    """Encode one raw frame with the encoder matching the codec, or return empty bytes for any other codec."""
    encoded_frame = bytes()

    if video_codec == 'av1':
        encoded_frame = aom_encoder.encode(video_encoder, raw_frame_bytes, resolution, frame_index)
    elif video_codec == 'vp8':
        encoded_frame = vpx_encoder.encode(video_encoder, raw_frame_bytes, resolution, frame_index)

    return encoded_frame
def receive_video_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[bytes]:
buffer_size = ctypes.c_int(512 * 1024)
receive_output = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size))
@@ -282,65 +302,3 @@ def receive_video_buffer(datachannel_library : ctypes.CDLL, video_track : int, r
return receive_buffer.raw[:buffer_size.value]
return None
def poll_for_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char], timeout : float) -> Optional[bytes]:
    """Poll the video track roughly every millisecond and return the first buffer received, or None after timeout seconds."""
    expires_at = time.monotonic() + timeout

    while time.monotonic() < expires_at:
        if (received_buffer := receive_video_buffer(datachannel_library, video_track, receive_buffer)) is not None:
            return received_buffer
        time.sleep(0.001)

    return None
#TODO: needs review
def poll_for_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char], timeout : float) -> Optional[VisionFrame]:
    """Poll roughly every millisecond for a decodable frame and return the first one, or None after timeout seconds."""
    expires_at = time.monotonic() + timeout

    while time.monotonic() < expires_at:
        if (decoded_frame := try_receive_frame(datachannel_library, video_track, video_codec, video_decoder, receive_buffer)) is not None:
            return decoded_frame
        time.sleep(0.001)

    return None
#TODO: needs review
def try_receive_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[VisionFrame]:
    """Make a single receive attempt and decode the result, returning None when nothing was received."""
    received_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer)

    if not received_buffer:
        return None
    return decode_video_frame(video_codec, video_decoder, received_buffer)
#TODO: needs review
def drain_to_latest_frame(datachannel_library : ctypes.CDLL, video_track : int, video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[VisionFrame]:
    """
    Read messages from the track until a receive fails or yields no data.

    Every received message is decoded; the last frame for which numpy.any()
    is true is returned, or None when there was no such frame.
    """
    latest_frame = numpy.empty(0)
    message_size = ctypes.c_int(512 * 1024)

    while True:
        # the library overwrites the size on each call, so the capacity is restored first
        message_size.value = 512 * 1024
        receive_status = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(message_size))

        if receive_status != 0 or message_size.value <= 0:
            break

        candidate_frame = decode_video_frame(video_codec, video_decoder, receive_buffer.raw[:message_size.value])

        if numpy.any(candidate_frame):
            latest_frame = candidate_frame

    return latest_frame if numpy.any(latest_frame) else None
+311
View File
@@ -0,0 +1,311 @@
import ctypes
import queue
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import cv2
import numpy
import pytest
from starlette.websockets import WebSocketState
from tests.assert_helper import get_test_example_file, get_test_examples_directory
from facefusion import rtc, rtc_store, state_manager
from facefusion.apis.endpoints.stream import websocket_stream
from facefusion.apis.stream_helper import decode_video_frame, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
from facefusion.codecs import aom_decoder, aom_encoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.download import conditional_download
from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
from facefusion.types import AudioFrame, RtcPeer, VideoCodec, VisionFrame
from facefusion.vision import read_video_frame
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
    """Initialise state, download the example media and pre-check the native libraries once per module."""
    example_urls =\
    [
        'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4',
        'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg'
    ]
    state_manager.init_item('download_providers', [ 'github', 'huggingface' ])
    state_manager.init_item('processors', [])
    conditional_download(get_test_examples_directory(), example_urls)

    for library_module in (aom_module, vpx_module, opus_module, datachannel_module):
        library_module.pre_check()
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
    """Clear the rtc store before every test."""
    rtc_store.clear()
# TODO: refine test
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_decode_video_frame(video_codec : VideoCodec) -> None:
    """Encode one example frame per codec and check the decoded frame's dimensions."""
    source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    frame_height, frame_width = source_frame.shape[:2]
    video_resolution = (frame_width, frame_height)
    yuv_buffer = cv2.cvtColor(source_frame, cv2.COLOR_BGR2YUV_I420).tobytes()

    if video_codec == 'av1':
        video_encoder = aom_encoder.create(video_resolution, 1000, 1, 0)
        encoded_buffer = aom_encoder.encode(video_encoder, yuv_buffer, video_resolution, 0)
        decoded_frame = decode_video_frame(video_codec, aom_decoder.create(8), encoded_buffer)

        assert decoded_frame is not None
        # only a lower bound is asserted for av1 - presumably the decoder may pad; TODO confirm
        assert decoded_frame.shape[1] >= frame_width
        assert decoded_frame.shape[0] >= frame_height
        assert decoded_frame.ndim == 3

    if video_codec == 'vp8':
        video_encoder = vpx_encoder.create(video_resolution, 1000, 1, 0)
        encoded_buffer = vpx_encoder.encode(video_encoder, yuv_buffer, video_resolution, 0)
        decoded_frame = decode_video_frame(video_codec, vpx_decoder.create(8), encoded_buffer)

        assert decoded_frame is not None
        assert decoded_frame.shape[1] == frame_width
        assert decoded_frame.shape[0] == frame_height
        assert decoded_frame.ndim == 3
# TODO: refine test
def test_decode_video_frame_empty_buffer() -> None:
    """An empty buffer must decode to None for both codecs."""
    empty_buffer = bytes()

    assert decode_video_frame('vp8', vpx_decoder.create(8), empty_buffer) is None
    assert decode_video_frame('av1', aom_decoder.create(8), empty_buffer) is None
# TODO: refine test
def test_pump_video_frames_keeps_latest_when_full() -> None:
    """Two frames are delivered into a queue of size one; exactly one decoded frame must remain."""
    source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    video_resolution = (source_frame.shape[1], source_frame.shape[0])
    yuv_buffer = cv2.cvtColor(source_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
    video_encoder = vpx_encoder.create(video_resolution, 1000, 1, 0)
    encoded_buffer = vpx_encoder.encode(video_encoder, yuv_buffer, video_resolution, 0)
    delivered_total = 0

    def receive_two(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
        # hand out the encoded frame twice, then report failure
        nonlocal delivered_total

        if delivered_total >= 2:
            return -1
        ctypes.memmove(buffer, encoded_buffer, len(encoded_buffer))
        ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_buffer)
        delivered_total += 1
        return 0

    mock_lib = MagicMock()
    mock_lib.rtcReceiveMessage.side_effect = receive_two
    video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
    stop_event = threading.Event()

    with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib):
        receiver = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue, stop_event), daemon = True)
        receiver.start()
        receiver.join(timeout = 2.0)
        stop_event.set()

    assert video_queue.qsize() == 1
    assert video_queue.get_nowait().shape[1] == video_resolution[0]
# TODO: refine test
def test_pump_audio_frames_delivers_decoded_frame() -> None:
    """One encoded Opus packet must come out of the queue as a float32 frame of 960 * 2 samples."""
    silent_audio = numpy.zeros(960 * 2, dtype = numpy.float32).tobytes()
    audio_encoder = opus_encoder.create(48000, 2)
    encoded_opus = opus_encoder.encode(audio_encoder, silent_audio, 960)
    packet_delivered = False

    def receive_once(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
        # hand out the encoded packet a single time, then report failure
        nonlocal packet_delivered

        if packet_delivered:
            return -1
        ctypes.memmove(buffer, encoded_opus, len(encoded_opus))
        ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_opus)
        packet_delivered = True
        return 0

    mock_lib = MagicMock()
    mock_lib.rtcReceiveMessage.side_effect = receive_once
    audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
    stop_event = threading.Event()

    with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib):
        receiver = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue, stop_event), daemon = True)
        receiver.start()
        audio_frame = audio_queue.get(timeout = 2.0)
        stop_event.set()
        receiver.join()

    assert audio_frame.dtype == numpy.float32
    assert audio_frame.size == 960 * 2
# TODO: refine test
def test_pump_audio_frames_skips_empty_frames() -> None:
    """When every receive call fails, nothing may be queued."""
    mock_lib = MagicMock()
    mock_lib.rtcReceiveMessage.return_value = -1
    audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
    stop_event = threading.Event()
    library_patch = patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib)

    with library_patch:
        receiver = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue, stop_event), daemon = True)
        receiver.start()
        # let the receiver spin for a moment before stopping it
        threading.Event().wait(timeout = 0.05)
        stop_event.set()
        receiver.join()

    assert audio_queue.empty()
# TODO: refine test
def test_run_peer_loop_processes_and_sends_frame() -> None:
    """Feed one encoded frame into the peer loop and expect a non-empty buffer to be passed to send_video."""
    source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    video_resolution = (source_frame.shape[1], source_frame.shape[0])
    yuv_buffer = cv2.cvtColor(source_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
    video_encoder = vpx_encoder.create(video_resolution, 1000, 1, 0)
    encoded_buffer = vpx_encoder.encode(video_encoder, yuv_buffer, video_resolution, 0)

    peer_connection = rtc.create_peer_connection()
    video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96)
    video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', 'vp8', 96)
    rtc_peer : RtcPeer =\
    {
        'peer_connection': peer_connection,
        'video':
        {
            'sender_track': video_sender_track,
            'receiver_track': video_receiver_track,
            'codec': 'vp8'
        }
    }
    session_id = 'test-run-peer-loop'
    rtc_store.init_peers(session_id)
    rtc_store.get_peers(session_id).append(rtc_peer)
    frame_delivered = False

    def receive_once(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
        # hand out the encoded frame a single time, then report failure
        nonlocal frame_delivered

        if frame_delivered:
            return -1
        ctypes.memmove(buffer, encoded_buffer, len(encoded_buffer))
        ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_buffer)
        frame_delivered = True
        return 0

    mock_lib = MagicMock()
    mock_lib.rtcReceiveMessage.side_effect = receive_once

    with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib), \
        patch('facefusion.apis.stream_helper.rtc.send_video') as mock_send_video:
        peer_thread = threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True)
        peer_thread.start()
        peer_thread.join(timeout = 5.0)

    assert mock_send_video.called
    assert len(mock_send_video.call_args[0][1]) > 0
# TODO: refine test
@pytest.mark.anyio
async def test_receive_vision_frames_yields_decoded_frames() -> None:
    """Two JPEG messages followed by a disconnect must yield two frames of the original shape."""
    source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    _, jpeg_buffer = cv2.imencode('.jpg', source_frame)
    jpeg_bytes = jpeg_buffer.tobytes()
    receive_event = {'type': 'websocket.receive', 'bytes': jpeg_bytes}
    mock_ws = AsyncMock()
    mock_ws.receive.side_effect =\
    [
        receive_event,
        receive_event,
        {'type': 'websocket.disconnect'}
    ]

    frames = [ frame async for frame in receive_vision_frames(mock_ws) ]

    assert len(frames) == 2
    assert frames[0].shape == source_frame.shape
# TODO: refine test
@pytest.mark.anyio
async def test_receive_vision_frames_skips_invalid_bytes() -> None:
    """Bytes that are not a decodable image must not yield a frame."""
    websocket_events =\
    [
        {'type': 'websocket.receive', 'bytes': b'not_a_jpeg'},
        {'type': 'websocket.disconnect'}
    ]
    mock_ws = AsyncMock()
    mock_ws.receive.side_effect = websocket_events

    frames = [ frame async for frame in receive_vision_frames(mock_ws) ]

    assert len(frames) == 0
# TODO: refine test
@pytest.mark.anyio
async def test_process_image_sends_processed_frame() -> None:
    """With a source configured, one received JPEG must result in exactly one JPEG being sent back."""
    target_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    _, jpeg_buffer = cv2.imencode('.jpg', target_frame)
    mock_ws = AsyncMock()
    mock_ws.receive.side_effect = [{'type': 'websocket.receive', 'bytes': jpeg_buffer.tobytes()}]
    state_manager.init_item('source_paths', [get_test_example_file('source.jpg')])

    await process_image(mock_ws)

    mock_ws.send_bytes.assert_called_once()
    sent_bytes = mock_ws.send_bytes.call_args[0][0]
    # JPEG streams start with the FF D8 FF marker bytes
    assert sent_bytes[:3] == b'\xff\xd8\xff'
# TODO: refine test
@pytest.mark.anyio
async def test_process_image_without_source_skips_send() -> None:
    """Without source paths nothing may be sent over the websocket."""
    state_manager.init_item('source_paths', None)
    mock_ws = AsyncMock()

    await process_image(mock_ws)

    mock_ws.send_bytes.assert_not_called()
# TODO: refine test
@pytest.mark.anyio
async def test_websocket_stream_accepts_and_closes() -> None:
    """The endpoint must accept the websocket once and close it once while it is still connected."""
    state_manager.init_item('source_paths', None)
    mock_ws = AsyncMock()
    mock_ws.scope = {'type': 'websocket', 'headers': []}
    mock_ws.client_state = WebSocketState.CONNECTED

    with patch('facefusion.apis.endpoints.stream.get_sec_websocket_protocol', return_value = None), \
        patch('facefusion.apis.endpoints.stream.extract_access_token', return_value = None), \
        patch('facefusion.apis.endpoints.stream.session_manager.find_session_id', return_value = None), \
        patch('facefusion.apis.endpoints.stream.session_context.set_session_id'):
        await websocket_stream(mock_ws)

    mock_ws.accept.assert_called_once()
    mock_ws.close.assert_called_once()
# TODO: refine test
@pytest.mark.anyio
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
async def test_process_video_returns_sdp_answer(video_codec : VideoCodec) -> None:
    """An offer with one video and one Opus track must be answered with recvonly and sendonly media."""
    video_payload_type = { 'av1': 35, 'vp8': 96 }.get(video_codec)
    sender_connection = rtc.create_peer_connection()

    if video_payload_type:
        rtc.add_video_track(sender_connection, 'sendrecv', video_codec, video_payload_type)

    rtc.add_audio_track(sender_connection, 'sendrecv', 'opus', 111)
    sdp_offer = rtc.create_sdp_offer(sender_connection)
    # the offering connection is only needed to produce the SDP text
    datachannel_module.create_static_library().rtcDeletePeerConnection(sender_connection)

    # the peer loop is patched out so only the negotiation is exercised
    with patch('facefusion.apis.stream_helper.run_peer_loop'):
        sdp_answer = process_video('test-process-video-' + video_codec, sdp_offer)

    assert sdp_answer is not None

    for expected_part in ('m=video', 'a=recvonly', 'a=sendonly'):
        assert expected_part in sdp_answer