Implement tier 1 REMB (#1129)

* implement tier 1

* fix lint

* cleanup

* cleanup

* fix lint

* use clear_remb method

* use single callback

* improve test

* improve test

* improve test

* improve test

* improve test
This commit is contained in:
Harisreedhar
2026-05-29 19:00:27 +05:30
committed by GitHub
parent cc0af9175a
commit 2a8672b54d
7 changed files with 117 additions and 19 deletions
+24 -12
View File
@@ -14,7 +14,7 @@ from facefusion import rtc, rtc_store, state_manager, streamer
from facefusion.audio import create_empty_audio_frame
from facefusion.codecs import aom_decoder, aom_encoder, opus_decoder, opus_encoder, vpx_decoder, vpx_encoder
from facefusion.libraries import datachannel as datachannel_module
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
from facefusion.types import AomDecoder, AomEncoder, AudioCodec, AudioFrame, BitRate, PeerConnection, Resolution, RtcPeer, SdpAnswer, SdpOffer, SessionId, VideoCodec, VisionFrame, VpxDecoder, VpxEncoder
async def process_image(websocket : WebSocket) -> None:
@@ -44,6 +44,8 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA
peer_connection : PeerConnection = rtc.create_peer_connection()
video_receiver_track = rtc.add_video_track(peer_connection, 'recvonly', video_codec, video_payload_type)
video_sender_track = rtc.add_video_track(peer_connection, 'sendonly', video_codec, video_payload_type)
bitrate = ctypes.c_uint(0)
rtc.wire_remb(video_sender_track, bitrate)
audio_codec : AudioCodec = 'opus'
audio_payload_type = rtc.get_payload_type(sdp_offer, audio_codec)
@@ -66,7 +68,8 @@ def process_video(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpA
'sender_track': video_sender_track,
'receiver_track': video_receiver_track,
'codec': video_codec
}
},
'bitrate': bitrate
}
if audio_receiver_track and audio_sender_track:
@@ -126,7 +129,8 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
if numpy.any(temp_vision_frame):
audio_frame = create_empty_audio_frame()
temp_resolution : Resolution = (temp_vision_frame.shape[1], temp_vision_frame.shape[0])
video_encoder = create_video_encoder(video_codec, temp_resolution)
temp_bitrate : BitRate = 8000
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
audio_encoder = opus_encoder.create(48000, 2)
frame_index = 0
@@ -140,14 +144,21 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
send_timestamp = time.monotonic()
if output_resolution == temp_resolution:
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
else:
destroy_video_encoder(video_codec, video_encoder) # TODO: remove unconditional destroy methods, which have no impact on control flow
peer_bitrate = rtc_peer.get('bitrate').value
if output_resolution != temp_resolution: # TODO avoid != in condition
destroy_video_encoder(video_codec, video_encoder)
temp_resolution = output_resolution
video_encoder = create_video_encoder(video_codec, temp_resolution)
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
frame_index = 0
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
if peer_bitrate and peer_bitrate != temp_bitrate: # TODO avoid != in condition
destroy_video_encoder(video_codec, video_encoder)
temp_bitrate = peer_bitrate
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
frame_index = 0
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
if output_video_buffer:
rtc.send_video(rtc_peer, output_video_buffer, int(send_timestamp * 90000))
@@ -163,6 +174,7 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
destroy_video_encoder(video_codec, video_encoder) # TODO: remove unconditional destroy methods, which have no impact on control flow
opus_encoder.destroy(audio_encoder)
rtc.clear_remb(rtc_peer)
for receiver_thread in receiver_threads:
receiver_thread.join()
@@ -262,12 +274,12 @@ def create_video_decoder(video_codec : VideoCodec) -> Optional[VpxDecoder | AomD
return None
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution) -> Optional[VpxEncoder | AomEncoder]:
def create_video_encoder(video_codec : VideoCodec, resolution : Resolution, bitrate : BitRate) -> Optional[VpxEncoder | AomEncoder]:
if video_codec == 'av1':
return aom_encoder.create(resolution, 8000, 8, 10)
return aom_encoder.create(resolution, bitrate, 8, 10)
if video_codec == 'vp8':
return vpx_encoder.create(resolution, 8000, 8, 10)
return vpx_encoder.create(resolution, bitrate, 8, 10)
return None
+6
View File
@@ -218,6 +218,12 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
library.rtcReceiveMessage.argtypes = [ ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(ctypes.c_int) ]
library.rtcReceiveMessage.restype = ctypes.c_int
library.rtcSetUserPointer.argtypes = [ ctypes.c_int, ctypes.c_void_p ]
library.rtcSetUserPointer.restype = None
library.rtcChainRembHandler.argtypes = [ ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_uint, ctypes.c_void_p) ]
library.rtcChainRembHandler.restype = ctypes.c_int
return library
+15
View File
@@ -216,6 +216,21 @@ def create_video_track_init(media_direction : MediaDirection, video_codec : Vide
return ctypes.byref(track_init)
@ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_uint, ctypes.c_void_p)
def handle_remb(track : int, bitrate : int, pointer : int) -> None:
ctypes.cast(pointer, ctypes.POINTER(ctypes.c_uint)).contents.value = bitrate // 1000
def wire_remb(video_track : RtcVideoTrack, bitrate : ctypes.c_uint) -> None:
datachannel_library = datachannel_module.create_static_library()
datachannel_library.rtcSetUserPointer(video_track, ctypes.cast(ctypes.byref(bitrate), ctypes.c_void_p))
datachannel_library.rtcChainRembHandler(video_track, handle_remb)
def clear_remb(rtc_peer : RtcPeer) -> None:
rtc_peer.get('bitrate').value = 0
def get_payload_type(sdp_offer : SdpOffer, codec : AudioCodec | VideoCodec) -> int:
datachannel_library = datachannel_module.create_static_library()
payload_type_buffer = (ctypes.c_int * 16)()
+1
View File
@@ -314,6 +314,7 @@ RtcPeer = TypedDict('RtcPeer',
'peer_connection': PeerConnection,
'audio': NotRequired[RtcPeerAudio],
'video': RtcPeerVideo,
'bitrate': ctypes.c_uint
})
RtcStore : TypeAlias = Dict[SessionId, List[RtcPeer]]
+9
View File
@@ -1,3 +1,4 @@
import ctypes
import tempfile
from typing import Iterator
from unittest.mock import patch
@@ -135,6 +136,14 @@ def test_stream_video(test_client : TestClient, video_codec : VideoCodec) -> Non
assert stream_response.status_code == 201
assert 'm=video' in stream_response.text
session_id = session_manager.find_session_id(access_token)
for peer in rtc_store.get_peers(session_id):
bitrate = peer.get('bitrate')
assert bitrate.value == 0
rtc.handle_remb(0, 6000000, ctypes.addressof(bitrate))
assert bitrate.value == 6000
def test_delete_stream_video(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
+34 -5
View File
@@ -1,11 +1,12 @@
import ctypes
from typing import List
import pytest
from facefusion import state_manager
from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, get_payload_type, send_audio, send_video, set_remote_description
from facefusion.types import RtcPeer
from facefusion.rtc import add_audio_track, add_video_track, create_peer_connection, create_sdp_answer, create_sdp_offer, delete_peers, get_payload_type, handle_remb, send_audio, send_video, set_remote_description, wire_remb
from facefusion.types import RtcPeer, VideoCodec
@pytest.fixture(scope = 'module', autouse = True)
@@ -77,7 +78,8 @@ def test_send_video() -> None:
'sender_track': video_track,
'receiver_track': video_track,
'codec': 'vp8'
}
},
'bitrate': ctypes.c_uint(0)
}
send_video(rtc_peer, bytes(1024), 0)
@@ -103,7 +105,8 @@ def test_send_audio() -> None:
'sender_track': audio_track,
'receiver_track': audio_track,
'codec': 'opus'
}
},
'bitrate': ctypes.c_uint(0)
}
send_audio(rtc_peer, bytes(960), 0)
@@ -123,7 +126,8 @@ def test_delete_peers() -> None:
'sender_track': 0,
'receiver_track': 0,
'codec': 'vp8'
}
},
'bitrate': ctypes.c_uint(0)
}
]
@@ -132,6 +136,31 @@ def test_delete_peers() -> None:
assert datachannel_library.rtcDeletePeerConnection(peer_connection) == -1
@pytest.mark.parametrize('video_codec, payload_type', [ ('av1', 35), ('vp8', 96) ])
def test_wire_remb(video_codec : VideoCodec, payload_type : int) -> None:
datachannel_library = datachannel_module.create_static_library()
peer_connection = create_peer_connection()
video_sender_track = add_video_track(peer_connection, 'sendonly', video_codec, payload_type)
rtc_peer : RtcPeer =\
{
'peer_connection': peer_connection,
'video':
{
'sender_track': video_sender_track,
'receiver_track': video_sender_track,
'codec': video_codec
},
'bitrate': ctypes.c_uint(0)
}
wire_remb(video_sender_track, rtc_peer.get('bitrate'))
handle_remb(0, 6000000, ctypes.addressof(rtc_peer.get('bitrate')))
assert rtc_peer.get('bitrate').value == 6000
datachannel_library.rtcDeletePeerConnection(peer_connection)
def test_get_payload_type() -> None:
peer_connection = create_peer_connection()
add_video_track(peer_connection, 'sendonly', 'vp8', 96)
+28 -2
View File
@@ -1,3 +1,4 @@
import ctypes
import queue
import threading
from unittest.mock import AsyncMock, MagicMock, patch
@@ -10,7 +11,7 @@ from tests.assert_helper import get_test_example_file, get_test_examples_directo
from facefusion import rtc, rtc_store, state_manager
from facefusion.apis.endpoints.stream import websocket_stream
from facefusion.apis.stream_helper import decode_video_frame, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder
from facefusion.common_helper import is_linux, is_macos, is_windows
from facefusion.download import conditional_download
@@ -94,6 +95,30 @@ def test_receive_video_frames() -> None:
assert create_hash(video_queue.get_nowait().tobytes()) == '38d00e2a'
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None:
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
resolution = (vision_frame.shape[1], vision_frame.shape[0])
frame_buffer = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2YUV_I420).tobytes()
encoder = create_video_encoder(video_codec, resolution, 4000)
assert encoder is not None
if video_codec == 'av1':
assert aom_encoder.encode(encoder, frame_buffer, resolution, 0)
if video_codec == 'vp8':
assert vpx_encoder.encode(encoder, frame_buffer, resolution, 0)
destroy_video_encoder(video_codec, encoder)
if video_codec == 'av1':
assert not aom_encoder.encode(encoder, frame_buffer, resolution, 1)
if video_codec == 'vp8':
assert not vpx_encoder.encode(encoder, frame_buffer, resolution, 1)
# TODO: refine test
def test_receive_audio_frames() -> None:
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)
@@ -128,7 +153,8 @@ def test_run_peer_loop() -> None:
'sender_track': video_sender_track,
'receiver_track': video_receiver_track,
'codec': 'vp8'
}
},
'bitrate': ctypes.c_uint(0)
}
session_id = 'test-run-peer-loop'