Fix stream lifecycle bugs: threading, RTP sync, and resource cleanup (#1125)

* fix executor thread not terminating after stream deletion

* fix stream shutdown and thread lifecycle

* add todo

* cleanup

* cleanup

* cleanup

* cleanup

* audio_queue.put() → get_nowait() + put_nowait()

* rename test

* fix test

* merge tests

* cleanup tests

* cleanup tests

* cleanup tests

* simplify test logic with mock

* cleanup

* cleanup hard to read stream_helper.py

* introduce rtc_peer.has_peers

* fix lint

* add todos

* fix test hash
This commit is contained in:
Harisreedhar
2026-05-22 16:29:45 +05:30
committed by GitHub
parent 520dcbfd6b
commit 4fe79483ea
3 changed files with 193 additions and 259 deletions
+86 -94
View File
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import ctypes
import queue
@@ -15,7 +14,7 @@ from facefusion import rtc, rtc_store, state_manager, streamer
from facefusion.audio import create_empty_audio_frame
from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.libraries import datachannel as datachannel_module
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, OpusDecoder, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
async def process_image(websocket : WebSocket) -> None:
@@ -35,9 +34,8 @@ async def process_image(websocket : WebSocket) -> None:
#TODO: needs review
def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
video_codec : VideoCodec = 'vp8'
av1_payload_type = rtc.get_payload_type(sdp_offer, 'av1')
if av1_payload_type:
if rtc.get_payload_type(sdp_offer, 'av1'):
video_codec = 'av1'
video_payload_type = rtc.get_payload_type(sdp_offer, video_codec)
@@ -82,15 +80,14 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA
rtc_store.init_peers(session_id)
rtc_store.get_peers(session_id).append(rtc_peer)
event_loop = asyncio.get_event_loop()
event_loop.run_in_executor(None, run_peer_loop, session_id, rtc_peer)
threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True).start()
return local_sdp
return local_sdp
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
return None
#TODO: needs review
async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFrame]:
websocket_event = await websocket.receive()
@@ -106,17 +103,20 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr
#TODO: needs review
def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
video_codec = rtc_peer.get('video').get('codec')
video_track = rtc_peer.get('video').get('receiver_track')
stop_event = threading.Event()
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
receiver_threads = [threading.Thread(target = receive_video_frames, args = (video_track, video_codec, video_queue, stop_event), daemon = True)]
receiver_threads = []
video_codec = rtc_peer.get('video').get('codec')
video_track = rtc_peer.get('video').get('receiver_track')
video_receiver_thread = threading.Thread(target = receive_video_frames, args = (video_track, video_codec, video_queue), daemon = True)
receiver_threads.append(video_receiver_thread)
if rtc_peer.get('audio'):
audio_codec = rtc_peer.get('audio').get('codec')
audio_codec : AudioCodec = 'opus'
audio_track = rtc_peer.get('audio').get('receiver_track')
receiver_threads.append(threading.Thread(target = receive_audio_frames, args = (audio_track, audio_codec, audio_queue, stop_event), daemon = True))
audio_receiver_thread = threading.Thread(target = receive_audio_frames, args = (audio_track, audio_codec, audio_queue), daemon = True)
receiver_threads.append(audio_receiver_thread)
for receiver_thread in receiver_threads:
receiver_thread.start()
@@ -138,16 +138,17 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
output_vision_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
send_timestamp = time.monotonic()
if output_resolution == temp_resolution:
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
else:
destroy_video_encoder(video_codec, video_encoder)
destroy_video_encoder(video_codec, video_encoder) # TODO: remove unconditional destroy methods, which have no impact on control flow
temp_resolution = output_resolution
video_encoder = create_video_encoder(video_codec, temp_resolution)
frame_index = 0
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
send_timestamp = time.monotonic()
if output_video_buffer:
rtc.send_video(rtc_peer, output_video_buffer, int(send_timestamp * 90000))
@@ -160,110 +161,67 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
frame_index += 1
temp_vision_frame = video_queue.get()
destroy_video_encoder(video_codec, video_encoder)
destroy_video_encoder(video_codec, video_encoder) # TODO: remove unconditional destroy methods, which have no impact on control flow
opus_encoder.destroy(audio_encoder)
stop_event.set()
for receiver_thread in receiver_threads:
receiver_thread.join()
rtc_store.delete_peers(session_id)
def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queue : queue.Queue[VisionFrame], stop_event : threading.Event) -> None:
def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queue : queue.Queue[VisionFrame]) -> None:
datachannel_library = datachannel_module.create_static_library()
video_decoder = create_video_decoder(video_codec)
receive_buffer = ctypes.create_string_buffer(512 * 1024)
receive_status_code = -3
while not stop_event.is_set(): # TODO: use positive while condition
frame_buffer = receive_video_buffer(datachannel_library, video_track, receive_buffer)
while receive_status_code == 0 or receive_status_code == -3:
buffer_size = ctypes.c_int(512 * 1024)
receive_status_code = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size))
if frame_buffer:
if receive_status_code == 0 and buffer_size.value > 0:
frame_buffer = receive_buffer.raw[:buffer_size.value]
vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer)
if numpy.any(vision_frame):
with contextlib.suppress(queue.Empty):
video_queue.get_nowait()
video_queue.put_nowait(vision_frame)
else:
stop_event.wait(timeout = 0.001) # TODO: remove this timeout
if receive_status_code == -3:
time.sleep(0.001) # TODO: remove sleep
video_queue.put(numpy.empty(0))
if video_codec == 'av1':
aom_decoder.destroy(video_decoder)
if video_codec == 'vp8':
vpx_decoder.destroy(video_decoder)
destroy_video_decoder(video_codec, video_decoder)
def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queue : queue.Queue[AudioFrame], stop_event : threading.Event) -> None:
def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queue : queue.Queue[AudioFrame]) -> None:
datachannel_library = datachannel_module.create_static_library()
audio_decoder = opus_decoder.create(48000, 2)
receive_buffer = ctypes.create_string_buffer(8 * 1024)
receive_status_code = -3
while not stop_event.is_set(): # TODO: use positive while condition
audio_frame = receive_audio_frame(datachannel_library, audio_track, audio_decoder, receive_buffer)
while receive_status_code == 0 or receive_status_code == -3:
buffer_size = ctypes.c_int(8 * 1024)
receive_status_code = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size))
if audio_frame.dtype == numpy.float32:
audio_queue.put(audio_frame)
else:
stop_event.wait(timeout = 0.001) # TODO: remove this timeout
if receive_status_code == 0 and buffer_size.value > 0:
opus_buffer = receive_buffer.raw[:buffer_size.value]
output_buffer = opus_decoder.decode(audio_decoder, opus_buffer, 960, 2)
if output_buffer:
with contextlib.suppress(queue.Empty):
audio_queue.get_nowait()
audio_queue.put_nowait(numpy.frombuffer(output_buffer, dtype = numpy.float32))
if receive_status_code == -3:
time.sleep(0.001) # TODO: remove sleep
opus_decoder.destroy(audio_decoder)
#TODO: needs review
def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomDecoder]:
if video_codec == 'av1':
return aom_decoder.create(8)
if video_codec == 'vp8':
return vpx_decoder.create(8)
return None
#TODO: needs review - remove as both are the same
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution) -> Optional[VpxEncoder | AomEncoder]:
if video_codec == 'av1':
return aom_encoder.create(resolution, 8000, 8, 10)
if video_codec == 'vp8':
return vpx_encoder.create(resolution, 8000, 8, 10)
return None
#TODO: needs review - remove as this is a trivial helper
def destroy_video_encoder(video_codec : VideoCodec, video_encoder : Optional[VpxEncoder | AomEncoder]) -> None:
if video_codec == 'av1':
aom_encoder.destroy(video_encoder)
if video_codec == 'vp8':
vpx_encoder.destroy(video_encoder)
def destroy_stream(session_id : SessionId) -> bool:
if rtc_store.get_peers(session_id):
rtc_store.delete_peers(session_id)
return True
return False
#TODO: needs review
def receive_audio_frame(datachannel_library : ctypes.CDLL, audio_track : int, audio_decoder : OpusDecoder, receive_buffer : ctypes.Array[ctypes.c_char]) -> AudioFrame:
buffer_size = ctypes.c_int(8 * 1024)
receive_output = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size))
if receive_output == 0 and buffer_size.value > 0:
opus_buffer = receive_buffer.raw[:buffer_size.value]
output_buffer = opus_decoder.decode(audio_decoder, opus_buffer, 960, 2)
if output_buffer:
return numpy.frombuffer(output_buffer, dtype = numpy.float32)
return create_empty_audio_frame()
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[VisionFrame]:
if video_codec == 'av1':
aom_pointer = aom_decoder.decode(video_decoder, frame_buffer)
@@ -294,11 +252,45 @@ def encode_video_frame(video_codec : VideoCodec, video_encoder : VpxEncoder | Ao
return bytes()
def receive_video_buffer(datachannel_library : ctypes.CDLL, video_track : int, receive_buffer : ctypes.Array[ctypes.c_char]) -> Optional[bytes]:
buffer_size = ctypes.c_int(512 * 1024)
receive_output = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size))
def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomDecoder]:
if video_codec == 'av1':
return aom_decoder.create(8)
if receive_output == 0 and buffer_size.value > 0:
return receive_buffer.raw[:buffer_size.value]
if video_codec == 'vp8':
return vpx_decoder.create(8)
return None
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution) -> Optional[VpxEncoder | AomEncoder]:
if video_codec == 'av1':
return aom_encoder.create(resolution, 8000, 8, 10)
if video_codec == 'vp8':
return vpx_encoder.create(resolution, 8000, 8, 10)
return None
def destroy_video_decoder(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder) -> None:
if video_codec == 'av1':
aom_decoder.destroy(video_decoder)
if video_codec == 'vp8':
vpx_decoder.destroy(video_decoder)
def destroy_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder) -> None:
if video_codec == 'av1':
aom_encoder.destroy(video_encoder)
if video_codec == 'vp8':
vpx_encoder.destroy(video_encoder)
def destroy_stream(session_id : SessionId) -> bool:
if rtc_store.has_peers(session_id):
rtc_store.delete_peers(session_id)
return not rtc_store.has_peers(session_id)
return False
+4
View File
@@ -25,5 +25,9 @@ def delete_peers(session_id : SessionId) -> None:
return None
def has_peers(session_id : SessionId) -> bool:
return bool(RTC_STORE.get(session_id))
def clear() -> None:
RTC_STORE.clear()
+103 -165
View File
@@ -1,4 +1,3 @@
import ctypes
import queue
import threading
from unittest.mock import AsyncMock, MagicMock, patch
@@ -12,8 +11,10 @@ from tests.assert_helper import get_test_example_file, get_test_examples_directo
from facefusion import rtc, rtc_store, state_manager
from facefusion.apis.endpoints.stream import websocket_stream
from facefusion.apis.stream_helper import decode_video_frame, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
from facefusion.codecs import aom_decoder, aom_encoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder
from facefusion.common_helper import is_linux, is_macos, is_windows
from facefusion.download import conditional_download
from facefusion.hash_helper import create_hash
from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
from facefusion.types import AudioFrame, RtcPeer, VideoCodec, VisionFrame
from facefusion.vision import read_video_frame
@@ -42,121 +43,79 @@ def before_each() -> None:
# TODO: refine test
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_decode_video_frame(video_codec : VideoCodec) -> None:
@pytest.mark.parametrize('video_codec', ['av1', 'vp8'])
def test_decode_video_frame(video_codec: VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
yuv_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
frame_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
if video_codec == 'av1':
encoded_buffer = aom_encoder.encode(aom_encoder.create(video_resolution, 1000, 1, 0), yuv_buffer, video_resolution, 0)
decoded_frame = decode_video_frame(video_codec, aom_decoder.create(8), encoded_buffer)
encode_frame_buffer = aom_encoder.encode(aom_encoder.create(video_resolution, 1000, 1, 0), frame_buffer, video_resolution, 0)
decode_frame_buffer = decode_video_frame(video_codec, aom_decoder.create(8), encode_frame_buffer).tobytes()
assert decoded_frame is not None
assert decoded_frame.shape[1] >= video_resolution[0]
assert decoded_frame.shape[0] >= video_resolution[1]
assert decoded_frame.ndim == 3
if is_linux() or is_windows():
assert create_hash(decode_frame_buffer) == '299b6ad6'
if is_macos():
assert create_hash(decode_frame_buffer) == '9f463b13'
assert decode_video_frame('av1', aom_decoder.create(8), bytes()) is None
if video_codec == 'vp8':
encoded_buffer = vpx_encoder.encode(vpx_encoder.create(video_resolution, 1000, 1, 0), yuv_buffer, video_resolution, 0)
decoded_frame = decode_video_frame(video_codec, vpx_decoder.create(8), encoded_buffer)
encode_frame_buffer = vpx_encoder.encode(vpx_encoder.create(video_resolution, 1000, 1, 0), frame_buffer, video_resolution, 0)
decode_frame_buffer = decode_video_frame(video_codec, vpx_decoder.create(8), encode_frame_buffer).tobytes()
assert decoded_frame is not None
assert decoded_frame.shape[1] == video_resolution[0]
assert decoded_frame.shape[0] == video_resolution[1]
assert decoded_frame.ndim == 3
if is_linux() or is_windows():
assert create_hash(decode_frame_buffer) == '99ef2c25'
if is_macos():
assert create_hash(decode_frame_buffer) == 'ff3ecb43'
assert decode_video_frame('vp8', vpx_decoder.create(8), bytes()) is None
# TODO: refine test
def test_decode_video_frame_empty_buffer() -> None:
assert decode_video_frame('vp8', vpx_decoder.create(8), bytes()) is None
assert decode_video_frame('av1', aom_decoder.create(8), bytes()) is None
# TODO: refine test
def test_pump_video_frames_keeps_latest_when_full() -> None:
source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_resolution = (source_frame.shape[1], source_frame.shape[0])
yuv_buffer = cv2.cvtColor(source_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
encoded_buffer = vpx_encoder.encode(vpx_encoder.create(video_resolution, 1000, 1, 0), yuv_buffer, video_resolution, 0)
mock_lib = MagicMock()
state : list[int] = [ 0 ]
def receive_two(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
if state[0] < 2:
ctypes.memmove(buffer, encoded_buffer, len(encoded_buffer))
ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_buffer)
state[0] += 1
return 0
return -1
mock_lib.rtcReceiveMessage.side_effect = receive_two
def test_receive_video_frames() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
stop_event = threading.Event()
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib):
receiver = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue, stop_event), daemon = True)
receiver.start()
receiver.join(timeout = 2.0)
stop_event.set()
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.decode_video_frame', return_value = vision_frame):
receiver_thread = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue), daemon = True)
receiver_thread.start()
receiver_thread.join(timeout = 2.0)
assert video_queue.qsize() == 1
assert video_queue.get_nowait().shape[1] == video_resolution[0]
if is_linux() or is_windows():
assert create_hash(video_queue.get_nowait().tobytes()) == 'a17439db'
if is_macos():
assert create_hash(video_queue.get_nowait().tobytes()) == '38d00e2a'
# TODO: refine test
def test_pump_audio_frames_delivers_decoded_frame() -> None:
audio_data = numpy.zeros(960 * 2, dtype = numpy.float32).tobytes()
encoded_opus = opus_encoder.encode(opus_encoder.create(48000, 2), audio_data, 960)
mock_lib = MagicMock()
state : list[bool] = [ False ]
def receive_once(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
if state[0]:
return -1
ctypes.memmove(buffer, encoded_opus, len(encoded_opus))
ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_opus)
state[0] = True
return 0
mock_lib.rtcReceiveMessage.side_effect = receive_once
def test_receive_audio_frames() -> None:
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
stop_event = threading.Event()
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib):
receiver = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue, stop_event), daemon = True)
receiver.start()
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.opus_decoder.decode', return_value = audio_frame.tobytes()):
receiver_thread = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue), daemon = True)
receiver_thread.start()
audio_frame = audio_queue.get(timeout = 2.0)
stop_event.set()
receiver.join()
receiver_thread.join(timeout = 1.0)
assert audio_frame.dtype == numpy.float32
assert audio_frame.size == 960 * 2
# TODO: refine test
def test_pump_audio_frames_skips_empty_frames() -> None:
mock_lib = MagicMock()
mock_lib.rtcReceiveMessage.return_value = -1
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
stop_event = threading.Event()
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib):
receiver = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue, stop_event), daemon = True)
receiver.start()
threading.Event().wait(timeout = 0.05)
stop_event.set()
receiver.join()
assert audio_queue.empty()
# TODO: refine test
def test_run_peer_loop_processes_and_sends_frame() -> None:
def test_run_peer_loop() -> None:
source_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_resolution = (source_frame.shape[1], source_frame.shape[0])
yuv_buffer = cv2.cvtColor(source_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
encoded_buffer = vpx_encoder.encode(vpx_encoder.create(video_resolution, 1000, 1, 0), yuv_buffer, video_resolution, 0)
peer_connection = rtc.create_peer_connection()
video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96)
@@ -176,20 +135,11 @@ def test_run_peer_loop_processes_and_sends_frame() -> None:
rtc_store.init_peers(session_id)
rtc_store.get_peers(session_id).append(rtc_peer)
mock_lib = MagicMock()
state : list[bool] = [ False ]
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
def receive_once(track : int, buffer : ctypes.Array[ctypes.c_char], size_byref : ctypes.c_void_p) -> int:
if state[0]:
return -1
ctypes.memmove(buffer, encoded_buffer, len(encoded_buffer))
ctypes.cast(size_byref, ctypes.POINTER(ctypes.c_int))[0] = len(encoded_buffer)
state[0] = True
return 0
mock_lib.rtcReceiveMessage.side_effect = receive_once
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = mock_lib), \
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.decode_video_frame', return_value = source_frame), \
patch('facefusion.apis.stream_helper.rtc.send_video') as mock_send_video:
thread = threading.Thread(target = run_peer_loop, args = (session_id, rtc_peer), daemon = True)
thread.start()
@@ -201,21 +151,32 @@ def test_run_peer_loop_processes_and_sends_frame() -> None:
# TODO: refine test
@pytest.mark.anyio
async def test_receive_vision_frames_yields_decoded_frames() -> None:
async def test_receive_vision_frames() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
_, jpeg_buffer = cv2.imencode('.jpg', vision_frame)
jpeg_bytes = jpeg_buffer.tobytes()
mock_ws = AsyncMock()
mock_ws.receive.side_effect =\
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect =\
[
{'type': 'websocket.receive', 'bytes': jpeg_bytes},
{'type': 'websocket.receive', 'bytes': jpeg_bytes},
{'type': 'websocket.disconnect'}
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.receive',
'bytes': b'invalid'
},
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.disconnect'
}
]
frames = []
async for frame in receive_vision_frames(mock_ws):
async for frame in receive_vision_frames(websocket_mock):
frames.append(frame)
assert len(frames) == 2
@@ -224,56 +185,34 @@ async def test_receive_vision_frames_yields_decoded_frames() -> None:
# TODO: refine test
@pytest.mark.anyio
async def test_receive_vision_frames_skips_invalid_bytes() -> None:
mock_ws = AsyncMock()
mock_ws.receive.side_effect =\
[
{'type': 'websocket.receive', 'bytes': b'not_a_jpeg'},
{'type': 'websocket.disconnect'}
]
frames = []
async for frame in receive_vision_frames(mock_ws):
frames.append(frame)
assert len(frames) == 0
# TODO: refine test
@pytest.mark.anyio
async def test_process_image_sends_processed_frame() -> None:
async def test_process_image() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
_, jpeg_buffer = cv2.imencode('.jpg', vision_frame)
mock_ws = AsyncMock()
mock_ws.receive.side_effect = [{'type': 'websocket.receive', 'bytes': jpeg_buffer.tobytes()}]
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect = [{'type': 'websocket.receive', 'bytes': frame_buffer}]
state_manager.init_item('source_paths', [get_test_example_file('source.jpg')])
await process_image(websocket_mock)
await process_image(mock_ws)
mock_ws.send_bytes.assert_called_once()
assert mock_ws.send_bytes.call_args[0][0][:3] == b'\xff\xd8\xff'
# TODO: refine test
@pytest.mark.anyio
async def test_process_image_without_source_skips_send() -> None:
mock_ws = AsyncMock()
websocket_mock.send_bytes.assert_called_once()
assert websocket_mock.send_bytes.call_args[0][0][:3] == b'\xff\xd8\xff'
state_manager.init_item('source_paths', None)
await process_image(websocket_mock)
await process_image(mock_ws)
mock_ws.send_bytes.assert_not_called()
websocket_mock.send_bytes.assert_called_once()
# TODO: refine test
@pytest.mark.anyio
async def test_websocket_stream_accepts_and_closes() -> None:
mock_ws = AsyncMock()
mock_ws.scope = {'type': 'websocket', 'headers': []}
mock_ws.client_state = WebSocketState.CONNECTED
async def test_websocket_stream() -> None:
websocket_mock = AsyncMock()
websocket_mock.scope =\
{
'type': 'websocket',
'headers': []
}
websocket_mock.client_state = WebSocketState.CONNECTED
state_manager.init_item('source_paths', None)
@@ -281,31 +220,30 @@ async def test_websocket_stream_accepts_and_closes() -> None:
patch('facefusion.apis.endpoints.stream.extract_access_token', return_value = None), \
patch('facefusion.apis.endpoints.stream.session_manager.find_session_id', return_value = None), \
patch('facefusion.apis.endpoints.stream.session_context.set_session_id'):
await websocket_stream(mock_ws)
await websocket_stream(websocket_mock)
mock_ws.accept.assert_called_once()
mock_ws.close.assert_called_once()
websocket_mock.accept.assert_called_once()
websocket_mock.close.assert_called_once()
# TODO: refine test
@pytest.mark.anyio
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
async def test_process_video_returns_sdp_answer(video_codec : VideoCodec) -> None:
sender_connection = rtc.create_peer_connection()
@pytest.mark.parametrize('video_codec, session_id', [ ('av1', 'test-process-video-av1'), ('vp8', 'test-process-video-vp8') ])
def test_process_video(video_codec : VideoCodec, session_id : str) -> None:
peer_connection = rtc.create_peer_connection()
if video_codec == 'av1':
rtc.add_video_track(sender_connection, 'sendrecv', video_codec, 35)
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 35)
if video_codec == 'vp8':
rtc.add_video_track(sender_connection, 'sendrecv', video_codec, 96)
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 96)
rtc.add_audio_track(sender_connection, 'sendrecv', 'opus', 111)
sdp_offer = rtc.create_sdp_offer(sender_connection)
datachannel_module.create_static_library().rtcDeletePeerConnection(sender_connection)
rtc.add_audio_track(peer_connection, 'sendrecv', 'opus', 111)
sdp_offer = rtc.create_sdp_offer(peer_connection)
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
with patch('facefusion.apis.stream_helper.run_peer_loop'):
sdp_answer = process_video('test-process-video-' + video_codec, sdp_offer)
with patch('facefusion.apis.stream_helper.threading.Thread'):
sdp_answer = process_video(session_id, sdp_offer)
assert sdp_answer is not None
assert sdp_answer
assert 'm=video' in sdp_answer
assert 'a=recvonly' in sdp_answer
assert 'a=sendonly' in sdp_answer