Add available event (#1134)

* remove sleep with available event

* reorder methods, caller has to follow variable names of consumer, reorder tests methods

* more todos for naming
This commit is contained in:
Henry Ruhs
2026-05-30 13:09:48 +02:00
committed by GitHub
parent 1ac0e3e9a4
commit 460c65004b
7 changed files with 246 additions and 219 deletions
-1
View File
@@ -9,7 +9,6 @@ from facefusion.apis.session_helper import extract_access_token
from facefusion.apis.stream_helper import destroy_stream, process_image, process_video
# TODO: can we avoid passing websocket? just the data if doable
async def websocket_stream(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
+48 -23
View File
@@ -4,6 +4,7 @@ import queue
import threading
import time
from collections.abc import AsyncIterator
from functools import partial
from typing import Optional
import cv2
@@ -14,7 +15,7 @@ from facefusion import rtc, rtc_store, state_manager, streamer
from facefusion.audio import create_empty_audio_frame
from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.libraries import datachannel as datachannel_module
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, BitRate, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, BitRate, PeerConnection, Resolution, RtcPeer, RtcPeerAudio, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
async def process_image(websocket : WebSocket) -> None:
@@ -51,6 +52,7 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA
audio_codec : AudioCodec = 'opus'
audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec)
#todo we try to avoid empty variables like that
audio_receiver_track = None
audio_sender_track = None
@@ -76,12 +78,11 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA
}
if audio_receiver_track and audio_sender_track:
rtc_peer['audio'] =\
{
'sender_track': audio_sender_track,
'receiver_track': audio_receiver_track,
'codec': audio_codec
}
rtc_peer['audio'] = RtcPeerAudio(
sender_track = audio_sender_track,
receiver_track = audio_receiver_track,
codec = audio_codec
)
rtc_store.init_peers(session_id)
rtc_store.get_peers(session_id).append(rtc_peer)
@@ -108,6 +109,7 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr
#TODO: needs review
#TODO: method is too complex
def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
# TODO: combine video and audio queue
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
@@ -144,13 +146,15 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
output_vision_frame = streamer.process_frame(audio_frame, temp_vision_frame)
output_resolution : Resolution = (output_vision_frame.shape[1], output_vision_frame.shape[0])
# TODO: align buffer naming with input/output and video/audio convention
output_vision_buffer = cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
send_timestamp = time.monotonic()
peer_bitrate = rtc_peer.get('sender_bitrate').value
if output_resolution != temp_resolution: # TODO avoid != in condition
# TODO: avoid != in condition
if output_resolution != temp_resolution:
destroy_video_encoder(video_codec, video_encoder)
temp_resolution = output_resolution
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
@@ -178,7 +182,8 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
frame_index += 1
temp_vision_frame = video_queue.get()
destroy_video_encoder(video_codec, video_encoder) # TODO: remove unconditional destroy methods, which have no impact on control flow
# TODO: remove unconditional destroy methods, which have no impact on control flow
destroy_video_encoder(video_codec, video_encoder)
opus_encoder.destroy(audio_encoder)
rtc.clear_remb(rtc_peer)
@@ -188,10 +193,16 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
rtc_store.delete_peers(session_id)
# TODO: method is too complex
def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queue : queue.Queue[VisionFrame]) -> None:
datachannel_library = datachannel_module.create_static_library()
video_decoder = create_video_decoder(video_codec)
receive_buffer = ctypes.create_string_buffer(512 * 1024)
# todo - could be prepare ready event
available_event = threading.Event()
available_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)(partial(dispatch_event, available_event))
datachannel_library.rtcSetAvailableCallback(video_track, available_callback)
# todo off
receive_status_code = -3
while receive_status_code == 0 or receive_status_code == -3:
@@ -199,8 +210,8 @@ def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queu
receive_status_code = datachannel_library.rtcReceiveMessage(video_track, receive_buffer, ctypes.byref(buffer_size))
if receive_status_code == 0 and buffer_size.value > 0:
# TODO: align buffer naming with input/output and video/audio convention
frame_buffer = receive_buffer.raw[:buffer_size.value]
#TODO: throttle decode to stream video fps or 30fps with todo
vision_frame = decode_video_frame(video_codec, video_decoder, frame_buffer)
if numpy.any(vision_frame):
@@ -209,17 +220,24 @@ def receive_video_frames(video_track : int, video_codec : VideoCodec, video_queu
video_queue.put_nowait(vision_frame)
if receive_status_code == -3:
# TODO: use rtcSetMessageCallback instead of polling
time.sleep(0.001)
available_event.wait()
available_event.clear()
video_queue.put(numpy.empty(0))
destroy_video_decoder(video_codec, video_decoder)
# TODO: audio_codec is not used but has to, even if there is just one
# TODO: method is too complex
def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queue : queue.Queue[AudioFrame]) -> None:
datachannel_library = datachannel_module.create_static_library()
audio_decoder = opus_decoder.create(48000, 2)
receive_buffer = ctypes.create_string_buffer(8 * 1024)
#todo - could be prepare ready event
available_event = threading.Event()
available_callback = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)(partial(dispatch_event, available_event))
datachannel_library.rtcSetAvailableCallback(audio_track, available_callback)
#todo off
receive_status_code = -3
while receive_status_code == 0 or receive_status_code == -3:
@@ -227,6 +245,7 @@ def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queu
receive_status_code = datachannel_library.rtcReceiveMessage(audio_track, receive_buffer, ctypes.byref(buffer_size))
if receive_status_code == 0 and buffer_size.value > 0:
# TODO: rename opus_buffer and output_buffer to audio convention
opus_buffer = receive_buffer.raw[:buffer_size.value]
output_buffer = opus_decoder.decode(audio_decoder, opus_buffer, 960, 2)
@@ -237,38 +256,40 @@ def receive_audio_frames(audio_track : int, audio_codec : AudioCodec, audio_queu
audio_queue.put_nowait(numpy.frombuffer(output_buffer, dtype = numpy.float32))
if receive_status_code == -3:
# TODO: use rtcSetMessageCallback instead of polling
time.sleep(0.001)
available_event.wait()
available_event.clear()
opus_decoder.destroy(audio_decoder)
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, frame_buffer : bytes) -> Optional[VisionFrame]:
def decode_video_frame(video_codec : VideoCodec, video_decoder : VpxDecoder | AomDecoder, input_buffer : bytes) -> Optional[VisionFrame]:
if video_codec == 'av1':
aom_pointer = aom_decoder.decode(video_decoder, frame_buffer)
aom_pointer = aom_decoder.decode(video_decoder, input_buffer)
if aom_pointer:
frame_width, frame_height = aom_pointer.get('resolution')
# TODO: move reshape and cvtColor into decoder modules
vision_frame = numpy.frombuffer(aom_pointer.get('buffer'), dtype = numpy.uint8).reshape((frame_height * 3 // 2, frame_width))
return cv2.cvtColor(vision_frame, cv2.COLOR_YUV2BGR_I420)
if video_codec == 'vp8':
vpx_pointer = vpx_decoder.decode(video_decoder, frame_buffer)
vpx_pointer = vpx_decoder.decode(video_decoder, input_buffer)
if vpx_pointer:
frame_width, frame_height = vpx_pointer.get('resolution')
# TODO: move reshape and cvtColor into decoder modules
vision_frame = numpy.frombuffer(vpx_pointer.get('buffer'), dtype = numpy.uint8).reshape((frame_height * 3 // 2, frame_width))
return cv2.cvtColor(vision_frame, cv2.COLOR_YUV2BGR_I420)
return None
def encode_video_frame(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, raw_frame_bytes : bytes, resolution : Resolution, frame_index : int) -> bytes:
def encode_video_frame(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, input_buffer : bytes, frame_resolution : Resolution, frame_index : int) -> bytes:
if video_codec == 'av1':
return aom_encoder.encode(video_encoder, raw_frame_bytes, resolution, frame_index)
return aom_encoder.encode(video_encoder, input_buffer, frame_resolution, frame_index)
if video_codec == 'vp8':
return vpx_encoder.encode(video_encoder, raw_frame_bytes, resolution, frame_index)
return vpx_encoder.encode(video_encoder, input_buffer, frame_resolution, frame_index)
return bytes()
@@ -283,12 +304,12 @@ def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomD
return None
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution, bitrate : BitRate) -> Optional[VpxEncoder | AomEncoder]:
def create_video_encoder(video_codec : VideoCodec, frame_resolution : Resolution, bitrate : BitRate) -> Optional[VpxEncoder | AomEncoder]:
if video_codec == 'av1':
return aom_encoder.create(resolution, bitrate, 8, 10)
return aom_encoder.create(frame_resolution, bitrate, 8, 10)
if video_codec == 'vp8':
return vpx_encoder.create(resolution, bitrate, 8, 10)
return vpx_encoder.create(frame_resolution, bitrate, 8, 10)
return None
@@ -325,3 +346,7 @@ def destroy_stream(session_id : SessionId) -> bool:
return not rtc_store.has_peers(session_id)
return False
def dispatch_event(event : threading.Event, track : int, pointer : ctypes.c_void_p) -> None:
event.set()
+3
View File
@@ -224,6 +224,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
library.rtcChainRembHandler.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_uint, ctypes.c_void_p) ]
library.rtcChainRembHandler.restype = ctypes.c_int
library.rtcSetAvailableCallback.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p) ]
library.rtcSetAvailableCallback.restype = ctypes.c_int
return library
+11 -11
View File
@@ -216,6 +216,17 @@ def create_video_track_init(media_direction : MediaDirection, video_codec : Vide
return ctypes.byref(track_init)
def get_payload_type(sdp_offer : SdpOffer, codec : AudioCodec | VideoCodec) -> int:
datachannel_library = datachannel_module.create_static_library()
payload_type_buffer = (ctypes.c_int * 16)()
payload_type_total = datachannel_library.rtcGetPayloadTypesForCodec(sdp_offer.encode(), codec.lower().encode(), payload_type_buffer, 16)
if payload_type_total:
return payload_type_buffer[0]
return 0
@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_uint, ctypes.c_void_p)
def handle_remb(track : int, bitrate : int, pointer : int) -> None:
ctypes.cast(pointer, ctypes.POINTER(ctypes.c_uint)).contents.value = bitrate // 1000
@@ -229,14 +240,3 @@ def wire_remb(video_track : RtcVideoTrack, bitrate : ctypes.c_uint) -> None:
def clear_remb(rtc_peer : RtcPeer) -> None:
rtc_peer.get('sender_bitrate').value = 0
def get_payload_type(sdp_offer : SdpOffer, codec : AudioCodec | VideoCodec) -> int:
datachannel_library = datachannel_module.create_static_library()
payload_type_buffer = (ctypes.c_int * 16)()
payload_type_total = datachannel_library.rtcGetPayloadTypesForCodec(sdp_offer.encode(), codec.lower().encode(), payload_type_buffer, 16)
if payload_type_total:
return payload_type_buffer[0]
return 0
+4 -4
View File
@@ -10,6 +10,10 @@ def init_peers(session_id : SessionId) -> None:
RTC_STORE[session_id] = []
def has_peers(session_id : SessionId) -> bool:
return bool(RTC_STORE.get(session_id))
def get_peers(session_id : SessionId) -> List[RtcPeer]:
return RTC_STORE.get(session_id)
@@ -25,9 +29,5 @@ def delete_peers(session_id : SessionId) -> None:
return None
def has_peers(session_id : SessionId) -> bool:
return bool(RTC_STORE.get(session_id))
def clear() -> None:
RTC_STORE.clear()
+13 -13
View File
@@ -139,6 +139,19 @@ def test_delete_peers() -> None:
assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1
def test_get_payload_type() -> None:
peer_connection = create_peer_connection()
add_video_track(peer_connection, 'sendonly', 'vp8', 96)
add_audio_track(peer_connection, 'sendonly', 'opus', 111)
sdp_offer = create_sdp_offer(peer_connection)
assert get_payload_type(sdp_offer, 'vp8') == 96
assert get_payload_type(sdp_offer, 'opus') == 111
assert get_payload_type(sdp_offer, 'av1') == 0
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
@pytest.mark.parametrize('video_codec, payload_type', [ ('av1', 35), ('vp8', 96) ])
def test_wire_remb(video_codec : VideoCodec, payload_type : int) -> None:
datachannel_library = datachannel_module.create_static_library()
@@ -192,16 +205,3 @@ def test_wire_remb_receiver(video_codec : VideoCodec, payload_type : int) -> Non
assert rtc_peer.get('receiver_bitrate').value == 6000
datachannel_library.rtcDeletePeerConnection(peer_connection)
def test_get_payload_type() -> None:
peer_connection = create_peer_connection()
add_video_track(peer_connection, 'sendonly', 'vp8', 96)
add_audio_track(peer_connection, 'sendonly', 'opus', 111)
sdp_offer = create_sdp_offer(peer_connection)
assert get_payload_type(sdp_offer, 'vp8') == 96
assert get_payload_type(sdp_offer, 'opus') == 111
assert get_payload_type(sdp_offer, 'av1') == 0
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
+167 -167
View File
@@ -45,121 +45,93 @@ def before_each() -> None:
# TODO: refine test
@pytest.mark.parametrize('video_codec', ['av1', 'vp8'])
def test_decode_video_frame(video_codec: VideoCodec) -> None:
@pytest.mark.anyio
async def test_process_image() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
video_resolution = (vision_frame.shape[1], vision_frame.shape[0])
frame_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect = [{'type': 'websocket.receive', 'bytes': frame_buffer}]
if video_codec == 'av1':
encode_frame_buffer = aom_encoder.encode(aom_encoder.create(video_resolution, 1000, 1, 0), frame_buffer, video_resolution, 0)
decode_frame_buffer = decode_video_frame(video_codec, aom_decoder.create(8), encode_frame_buffer).tobytes()
state_manager.init_item('source_paths', [get_test_example_file('source.jpg')])
await process_image(websocket_mock)
if is_linux() or is_windows():
assert create_hash(decode_frame_buffer) == '299b6ad6'
websocket_mock.send_bytes.assert_called_once()
assert websocket_mock.send_bytes.call_args[0][0][:3] == b'\xff\xd8\xff'
if is_macos():
assert create_hash(decode_frame_buffer) == '9f463b13'
state_manager.init_item('source_paths', None)
await process_image(websocket_mock)
assert decode_video_frame('av1', aom_decoder.create(8), bytes()) is None
if video_codec == 'vp8':
encode_frame_buffer = vpx_encoder.encode(vpx_encoder.create(video_resolution, 1000, 1, 0), frame_buffer, video_resolution, 0)
decode_frame_buffer = decode_video_frame(video_codec, vpx_decoder.create(8), encode_frame_buffer).tobytes()
if is_linux() or is_windows():
assert create_hash(decode_frame_buffer) == '99ef2c25'
if is_macos():
assert create_hash(decode_frame_buffer) == 'ff3ecb43'
assert decode_video_frame('vp8', vpx_decoder.create(8), bytes()) is None
websocket_mock.send_bytes.assert_called_once()
# TODO: refine test
def test_receive_video_frames() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.decode_video_frame', return_value = vision_frame):
receiver_thread = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue), daemon = True)
receiver_thread.start()
receiver_thread.join(timeout = 2.0)
if is_linux() or is_windows():
assert create_hash(video_queue.get_nowait().tobytes()) == 'a17439db'
if is_macos():
assert create_hash(video_queue.get_nowait().tobytes()) == '38d00e2a'
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
resolution = (vision_frame.shape[1], vision_frame.shape[0])
frame_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
video_encoder = create_video_encoder(video_codec, resolution, 4000)
@pytest.mark.parametrize('video_codec, session_id', [ ('av1', 'test-process-video-av1'), ('vp8', 'test-process-video-vp8') ])
def test_process_video(video_codec : VideoCodec, session_id : str) -> None:
peer_connection = rtc.create_peer_connection()
if video_codec == 'av1':
assert aom_encoder.encode(video_encoder, frame_buffer, resolution, 0)
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 35)
if video_codec == 'vp8':
assert vpx_encoder.encode(video_encoder, frame_buffer, resolution, 0)
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 96)
destroy_video_encoder(video_codec, video_encoder)
rtc.add_audio_track(peer_connection, 'sendrecv', 'opus', 111)
sdp_offer = rtc.create_sdp_offer(peer_connection)
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
if video_codec == 'av1':
assert not aom_encoder.encode(video_encoder, frame_buffer, resolution, 1)
with patch('facefusion.apis.stream_helper.threading.Thread'):
sdp_answer = process_video(session_id, sdp_offer)
if video_codec == 'vp8':
assert not vpx_encoder.encode(video_encoder, frame_buffer, resolution, 1)
assert sdp_answer
assert 'm=video' in sdp_answer
assert 'a=recvonly' in sdp_answer
assert 'a=sendonly' in sdp_answer
for peer in rtc_store.get_peers(session_id):
sender_bitrate = peer.get('sender_bitrate')
receiver_bitrate = peer.get('receiver_bitrate')
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
resolution = (vision_frame.shape[1], vision_frame.shape[0])
assert sender_bitrate.value == 0
assert receiver_bitrate.value == 0
video_encoder = create_video_encoder(video_codec, resolution, 4000)
rtc.handle_remb(0, 6000000, ctypes.addressof(sender_bitrate))
assert sender_bitrate.value == 6000
if video_codec == 'av1':
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 4000
if video_codec == 'vp8':
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 4000
assert update_video_encoder_bitrate(video_codec, video_encoder, 6000)
if video_codec == 'av1':
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 6000
if video_codec == 'vp8':
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 6000
destroy_video_encoder(video_codec, video_encoder)
rtc.handle_remb(0, 4000000, ctypes.addressof(receiver_bitrate))
assert receiver_bitrate.value == 4000
# TODO: refine test
def test_receive_audio_frames() -> None:
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
@pytest.mark.anyio
async def test_receive_vision_frames() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect =\
[
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.receive',
'bytes': b'invalid'
},
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.disconnect'
}
]
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.opus_decoder.decode', return_value = audio_frame.tobytes()):
receiver_thread = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue), daemon = True)
receiver_thread.start()
audio_frame = audio_queue.get(timeout = 2.0)
receiver_thread.join(timeout = 1.0)
frames = []
assert audio_frame.dtype == numpy.float32
assert audio_frame.size == 960 * 2
assert audio_queue.empty()
async for frame in receive_vision_frames(websocket_mock):
frames.append(frame)
assert len(frames) == 2
assert frames[0].shape == vision_frame.shape
# TODO: refine test
@@ -201,57 +173,121 @@ def test_run_peer_loop() -> None:
# TODO: refine test
@pytest.mark.anyio
async def test_receive_vision_frames() -> None:
def test_receive_video_frames() -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect =\
[
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.receive',
'bytes': b'invalid'
},
{
'type': 'websocket.receive',
'bytes': frame_buffer
},
{
'type': 'websocket.disconnect'
}
]
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
video_queue : queue.Queue[VisionFrame] = queue.Queue(maxsize = 1)
frames = []
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.decode_video_frame', return_value = vision_frame):
receiver_thread = threading.Thread(target = receive_video_frames, args = (0, 'vp8', video_queue), daemon = True)
receiver_thread.start()
receiver_thread.join(timeout = 2.0)
async for frame in receive_vision_frames(websocket_mock):
frames.append(frame)
if is_linux() or is_windows():
assert create_hash(video_queue.get_nowait().tobytes()) == 'a17439db'
assert len(frames) == 2
assert frames[0].shape == vision_frame.shape
if is_macos():
assert create_hash(video_queue.get_nowait().tobytes()) == '38d00e2a'
# TODO: refine test
@pytest.mark.anyio
async def test_process_image() -> None:
def test_receive_audio_frames() -> None:
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)
datachannel_library_mock = MagicMock()
datachannel_library_mock.rtcReceiveMessage.side_effect = [ 0, -1 ]
audio_queue : queue.Queue[AudioFrame] = queue.Queue(maxsize = 4)
with patch('facefusion.apis.stream_helper.datachannel_module.create_static_library', return_value = datachannel_library_mock), \
patch('facefusion.apis.stream_helper.opus_decoder.decode', return_value = audio_frame.tobytes()):
receiver_thread = threading.Thread(target = receive_audio_frames, args = (0, 'opus', audio_queue), daemon = True)
receiver_thread.start()
audio_frame = audio_queue.get(timeout = 2.0)
receiver_thread.join(timeout = 1.0)
assert audio_frame.dtype == numpy.float32
assert audio_frame.size == 960 * 2
assert audio_queue.empty()
# TODO: refine test
@pytest.mark.parametrize('video_codec', ['av1', 'vp8'])
def test_decode_video_frame(video_codec: VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
frame_buffer = cv2.imencode('.jpg', vision_frame)[1].tobytes()
websocket_mock = AsyncMock()
websocket_mock.receive.side_effect = [{'type': 'websocket.receive', 'bytes': frame_buffer}]
frame_resolution = (vision_frame.shape[1], vision_frame.shape[0])
input_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
state_manager.init_item('source_paths', [get_test_example_file('source.jpg')])
await process_image(websocket_mock)
if video_codec == 'av1':
encode_buffer = aom_encoder.encode(aom_encoder.create(frame_resolution, 1000, 1, 0), input_buffer, frame_resolution, 0)
decode_buffer = decode_video_frame(video_codec, aom_decoder.create(8), encode_buffer).tobytes()
websocket_mock.send_bytes.assert_called_once()
assert websocket_mock.send_bytes.call_args[0][0][:3] == b'\xff\xd8\xff'
if is_linux() or is_windows():
assert create_hash(decode_buffer) == '299b6ad6'
state_manager.init_item('source_paths', None)
await process_image(websocket_mock)
if is_macos():
assert create_hash(decode_buffer) == '9f463b13'
websocket_mock.send_bytes.assert_called_once()
assert decode_video_frame('av1', aom_decoder.create(8), bytes()) is None
if video_codec == 'vp8':
encode_buffer = vpx_encoder.encode(vpx_encoder.create(frame_resolution, 1000, 1, 0), input_buffer, frame_resolution, 0)
decode_buffer = decode_video_frame(video_codec, vpx_decoder.create(8), encode_buffer).tobytes()
if is_linux() or is_windows():
assert create_hash(decode_buffer) == '99ef2c25'
if is_macos():
assert create_hash(decode_buffer) == 'ff3ecb43'
assert decode_video_frame('vp8', vpx_decoder.create(8), bytes()) is None
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
frame_resolution = (vision_frame.shape[1], vision_frame.shape[0])
input_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
video_encoder = create_video_encoder(video_codec, frame_resolution, 4000)
if video_codec == 'av1':
assert aom_encoder.encode(video_encoder, input_buffer, frame_resolution, 0)
if video_codec == 'vp8':
assert vpx_encoder.encode(video_encoder, input_buffer, frame_resolution, 0)
destroy_video_encoder(video_codec, video_encoder)
if video_codec == 'av1':
assert not aom_encoder.encode(video_encoder, input_buffer, frame_resolution, 1)
if video_codec == 'vp8':
assert not vpx_encoder.encode(video_encoder, input_buffer, frame_resolution, 1)
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
frame_resolution = (vision_frame.shape[1], vision_frame.shape[0])
video_encoder = create_video_encoder(video_codec, frame_resolution, 4000)
if video_codec == 'av1':
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 4000
if video_codec == 'vp8':
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 4000
assert update_video_encoder_bitrate(video_codec, video_encoder, 6000)
if video_codec == 'av1':
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 6000
if video_codec == 'vp8':
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 6000
destroy_video_encoder(video_codec, video_encoder)
# TODO: refine test
@@ -275,39 +311,3 @@ async def test_websocket_stream() -> None:
websocket_mock.accept.assert_called_once()
websocket_mock.close.assert_called_once()
# TODO: refine test
@pytest.mark.parametrize('video_codec, session_id', [ ('av1', 'test-process-video-av1'), ('vp8', 'test-process-video-vp8') ])
def test_process_video(video_codec : VideoCodec, session_id : str) -> None:
peer_connection = rtc.create_peer_connection()
if video_codec == 'av1':
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 35)
if video_codec == 'vp8':
rtc.add_video_track(peer_connection, 'sendrecv', video_codec, 96)
rtc.add_audio_track(peer_connection, 'sendrecv', 'opus', 111)
sdp_offer = rtc.create_sdp_offer(peer_connection)
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
with patch('facefusion.apis.stream_helper.threading.Thread'):
sdp_answer = process_video(session_id, sdp_offer)
assert sdp_answer
assert 'm=video' in sdp_answer
assert 'a=recvonly' in sdp_answer
assert 'a=sendonly' in sdp_answer
for peer in rtc_store.get_peers(session_id):
sender_bitrate = peer.get('sender_bitrate')
receiver_bitrate = peer.get('receiver_bitrate')
assert sender_bitrate.value == 0
assert receiver_bitrate.value == 0
rtc.handle_remb(0, 6000000, ctypes.addressof(sender_bitrate))
assert sender_bitrate.value == 6000
rtc.handle_remb(0, 4000000, ctypes.addressof(receiver_bitrate))
assert receiver_bitrate.value == 4000