feat(rtc): REMB bitrate adaptation with in-place encoder update (#1130)

* improve test

* fix lint

* cleanup

* cleanup
This commit is contained in:
Harisreedhar
2026-05-29 21:15:03 +05:30
committed by GitHub
parent 871559cb6a
commit 6b9ddd9a4f
7 changed files with 72 additions and 8 deletions
+16 -4
View File
@@ -152,11 +152,13 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
frame_index = 0
if peer_bitrate and peer_bitrate != temp_bitrate: # TODO avoid != in condition
destroy_video_encoder(video_codec, video_encoder)
if peer_bitrate and peer_bitrate - temp_bitrate:
temp_bitrate = peer_bitrate
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
frame_index = 0
if not update_video_encoder_bitrate(video_codec, video_encoder, temp_bitrate):
destroy_video_encoder(video_codec, video_encoder)
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
frame_index = 0
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
@@ -292,6 +294,16 @@ def destroy_video_decoder(video_codec : VideoCodec, video_decoder : VpxDecoder |
vpx_decoder.destroy(video_decoder)
def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, bitrate : BitRate) -> bool:
    """
    Forward an in-place bitrate change to the encoder module of the given codec.

    :param video_codec: codec name; 'av1' is routed to aom_encoder and 'vp8' to vpx_encoder
    :param video_encoder: the encoder handle whose bitrate should change
    :param bitrate: the new bitrate, passed through unchanged
    :return: whatever the codec specific update_bitrate() reports, or False
        when the codec is neither 'av1' nor 'vp8'
    """
    is_updated = False

    if video_codec == 'av1':
        is_updated = aom_encoder.update_bitrate(video_encoder, bitrate)
    elif video_codec == 'vp8':
        is_updated = vpx_encoder.update_bitrate(video_encoder, bitrate)

    return is_updated
def destroy_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder) -> None:
if video_codec == 'av1':
aom_encoder.destroy(video_encoder)
+12 -1
View File
@@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
aom_library = aom_module.create_static_library()
if aom_library:
aom_encoder = ctypes.create_string_buffer(128)
aom_encoder = ctypes.create_string_buffer(1152)
aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_cx_algo')
config_buffer = ctypes.create_string_buffer(1024)
@@ -28,6 +28,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
aom_library.aom_codec_control(aom_encoder, 106, ctypes.c_int(1))
aom_library.aom_codec_control(aom_encoder, 122, ctypes.c_int(0))
aom_library.aom_codec_control(aom_encoder, 123, ctypes.c_int(0))
ctypes.memmove(ctypes.addressof(aom_encoder) + 128, config_buffer, 1024)
return aom_encoder
return None
@@ -65,6 +66,16 @@ def collect(aom_encoder : AomEncoder) -> bytes:
return bytes().join(output_parts)
def update_bitrate(aom_encoder : AomEncoder, bitrate : BitRate) -> bool:
    """
    Change the bitrate of a live AV1 encoder without recreating it.

    The value is written as an unsigned int into the configuration copy that
    create() places at byte 128 of the encoder buffer, and that configuration
    is then handed to aom_codec_enc_config_set().

    :param aom_encoder: the encoder buffer returned by create()
    :param bitrate: the new bitrate value to store
    :return: True when aom_codec_enc_config_set() returns 0, False when it
        returns anything else or when the aom library is not available
    """
    aom_library = aom_module.create_static_library()

    if not aom_library:
        return False

    # the configuration copy starts right after the first 128 bytes of the buffer
    config_offset = 128
    # NOTE(review): 136 is presumably the position of the target bitrate field
    # inside the libaom encoder configuration - confirm against the struct layout
    struct.pack_into('I', aom_encoder, config_offset + 136, bitrate)
    config_pointer = ctypes.cast(ctypes.addressof(aom_encoder) + config_offset, ctypes.c_void_p)
    return aom_library.aom_codec_enc_config_set(aom_encoder, config_pointer) == 0
def destroy(aom_encoder : AomEncoder) -> None:
aom_library = aom_module.create_static_library()
+12 -1
View File
@@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
vpx_library = vpx_module.create_static_library()
if vpx_library:
vpx_encoder = ctypes.create_string_buffer(64)
vpx_encoder = ctypes.create_string_buffer(640)
vp8_codec = ctypes.c_void_p.in_dll(vpx_library, 'vpx_codec_vp8_cx_algo')
config_buffer = ctypes.create_string_buffer(512)
@@ -32,6 +32,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
vpx_library.vpx_codec_control_(vpx_encoder, 13, ctypes.c_int(cpu_count))
vpx_library.vpx_codec_control_(vpx_encoder, 12, ctypes.c_int(3))
vpx_library.vpx_codec_control_(vpx_encoder, 27, ctypes.c_int(10))
ctypes.memmove(ctypes.addressof(vpx_encoder) + 64, config_buffer, 512)
return vpx_encoder
return None
@@ -69,6 +70,16 @@ def collect(vpx_encoder : VpxEncoder) -> bytes:
return bytes().join(output_parts)
def update_bitrate(vpx_encoder : VpxEncoder, bitrate : BitRate) -> bool:
    """
    Change the bitrate of a live VP8 encoder without recreating it.

    The value is written as an unsigned int into the configuration copy that
    create() places at byte 64 of the encoder buffer, and that configuration
    is then handed to vpx_codec_enc_config_set().

    :param vpx_encoder: the encoder buffer returned by create()
    :param bitrate: the new bitrate value to store
    :return: True when vpx_codec_enc_config_set() returns 0, False when it
        returns anything else or when the vpx library is not available
    """
    vpx_library = vpx_module.create_static_library()

    if not vpx_library:
        return False

    # the configuration copy starts right after the first 64 bytes of the buffer
    config_offset = 64
    # NOTE(review): 112 is presumably the position of the target bitrate field
    # inside the libvpx encoder configuration - confirm against the struct layout
    struct.pack_into('I', vpx_encoder, config_offset + 112, bitrate)
    config_pointer = ctypes.cast(ctypes.addressof(vpx_encoder) + config_offset, ctypes.c_void_p)
    return vpx_library.vpx_codec_enc_config_set(vpx_encoder, config_pointer) == 0
def destroy(vpx_encoder : VpxEncoder) -> None:
vpx_library = vpx_module.create_static_library()
+3
View File
@@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
library.aom_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
library.aom_codec_get_cx_data.restype = ctypes.c_void_p
library.aom_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ]
library.aom_codec_enc_config_set.restype = ctypes.c_int
library.aom_codec_destroy.argtypes = [ ctypes.c_void_p ]
library.aom_codec_destroy.restype = ctypes.c_int
+3
View File
@@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
library.vpx_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
library.vpx_codec_get_cx_data.restype = ctypes.c_void_p
library.vpx_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ]
library.vpx_codec_enc_config_set.restype = ctypes.c_int
library.vpx_codec_destroy.argtypes = [ ctypes.c_void_p ]
library.vpx_codec_destroy.restype = ctypes.c_int
-1
View File
@@ -1,4 +1,3 @@
import ctypes
import tempfile
from typing import Iterator
from unittest.mock import patch
+26 -1
View File
@@ -1,5 +1,6 @@
import ctypes
import queue
import struct
import threading
from unittest.mock import AsyncMock, MagicMock, patch
@@ -11,7 +12,7 @@ from tests.assert_helper import get_test_example_file, get_test_examples_directo
from facefusion import rtc, rtc_store, state_manager
from facefusion.apis.endpoints.stream import websocket_stream
from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop, update_video_encoder_bitrate
from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder
from facefusion.common_helper import is_linux, is_macos, is_windows
from facefusion.download import conditional_download
@@ -118,6 +119,30 @@ def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None:
assert not vpx_encoder.encode(video_encoder, frame_buffer, resolution, 1)
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None:
    # byte position inside the encoder buffer where each codec keeps the bitrate
    bitrate_offsets =\
    {
        'av1': 128 + 136,
        'vp8': 64 + 112
    }
    bitrate_offset = bitrate_offsets[video_codec]
    vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
    frame_width = vision_frame.shape[1]
    frame_height = vision_frame.shape[0]
    video_encoder = create_video_encoder(video_codec, (frame_width, frame_height), 4000)

    # the bitrate passed at creation must be readable from the buffer
    assert struct.unpack_from('I', video_encoder, bitrate_offset)[0] == 4000

    # an in-place update must succeed and replace the stored value
    assert update_video_encoder_bitrate(video_codec, video_encoder, 6000)
    assert struct.unpack_from('I', video_encoder, bitrate_offset)[0] == 6000

    destroy_video_encoder(video_codec, video_encoder)
# TODO: refine test
def test_receive_audio_frames() -> None:
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)