mirror of
https://github.com/facefusion/facefusion.git
synced 2026-06-02 19:01:35 +02:00
feat(rtc): REMB bitrate adaptation with in-place encoder update (#1130)
* improve test * fix lint * cleanup * cleanup
This commit is contained in:
@@ -152,11 +152,13 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None:
|
||||
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
|
||||
frame_index = 0
|
||||
|
||||
if peer_bitrate and peer_bitrate != temp_bitrate: # TODO avoid != in condition
|
||||
destroy_video_encoder(video_codec, video_encoder)
|
||||
if peer_bitrate and peer_bitrate - temp_bitrate:
|
||||
temp_bitrate = peer_bitrate
|
||||
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
|
||||
frame_index = 0
|
||||
|
||||
if not update_video_encoder_bitrate(video_codec, video_encoder, temp_bitrate):
|
||||
destroy_video_encoder(video_codec, video_encoder)
|
||||
video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate)
|
||||
frame_index = 0
|
||||
|
||||
output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index)
|
||||
|
||||
@@ -292,6 +294,16 @@ def destroy_video_decoder(video_codec : VideoCodec, video_decoder : VpxDecoder |
|
||||
vpx_decoder.destroy(video_decoder)
|
||||
|
||||
|
||||
def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, bitrate : BitRate) -> bool:
|
||||
if video_codec == 'av1':
|
||||
return aom_encoder.update_bitrate(video_encoder, bitrate)
|
||||
|
||||
if video_codec == 'vp8':
|
||||
return vpx_encoder.update_bitrate(video_encoder, bitrate)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def destroy_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder) -> None:
|
||||
if video_codec == 'av1':
|
||||
aom_encoder.destroy(video_encoder)
|
||||
|
||||
@@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
if aom_library:
|
||||
aom_encoder = ctypes.create_string_buffer(128)
|
||||
aom_encoder = ctypes.create_string_buffer(1152)
|
||||
aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_cx_algo')
|
||||
|
||||
config_buffer = ctypes.create_string_buffer(1024)
|
||||
@@ -28,6 +28,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
|
||||
aom_library.aom_codec_control(aom_encoder, 106, ctypes.c_int(1))
|
||||
aom_library.aom_codec_control(aom_encoder, 122, ctypes.c_int(0))
|
||||
aom_library.aom_codec_control(aom_encoder, 123, ctypes.c_int(0))
|
||||
ctypes.memmove(ctypes.addressof(aom_encoder) + 128, config_buffer, 1024)
|
||||
return aom_encoder
|
||||
|
||||
return None
|
||||
@@ -65,6 +66,16 @@ def collect(aom_encoder : AomEncoder) -> bytes:
|
||||
return bytes().join(output_parts)
|
||||
|
||||
|
||||
def update_bitrate(aom_encoder : AomEncoder, bitrate : BitRate) -> bool:
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
if aom_library:
|
||||
struct.pack_into('I', aom_encoder, 128 + 136, bitrate)
|
||||
return aom_library.aom_codec_enc_config_set(aom_encoder, ctypes.cast(ctypes.addressof(aom_encoder) + 128, ctypes.c_void_p)) == 0
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def destroy(aom_encoder : AomEncoder) -> None:
|
||||
aom_library = aom_module.create_static_library()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
if vpx_library:
|
||||
vpx_encoder = ctypes.create_string_buffer(64)
|
||||
vpx_encoder = ctypes.create_string_buffer(640)
|
||||
vp8_codec = ctypes.c_void_p.in_dll(vpx_library, 'vpx_codec_vp8_cx_algo')
|
||||
|
||||
config_buffer = ctypes.create_string_buffer(512)
|
||||
@@ -32,6 +32,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int,
|
||||
vpx_library.vpx_codec_control_(vpx_encoder, 13, ctypes.c_int(cpu_count))
|
||||
vpx_library.vpx_codec_control_(vpx_encoder, 12, ctypes.c_int(3))
|
||||
vpx_library.vpx_codec_control_(vpx_encoder, 27, ctypes.c_int(10))
|
||||
ctypes.memmove(ctypes.addressof(vpx_encoder) + 64, config_buffer, 512)
|
||||
return vpx_encoder
|
||||
|
||||
return None
|
||||
@@ -69,6 +70,16 @@ def collect(vpx_encoder : VpxEncoder) -> bytes:
|
||||
return bytes().join(output_parts)
|
||||
|
||||
|
||||
def update_bitrate(vpx_encoder : VpxEncoder, bitrate : BitRate) -> bool:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
if vpx_library:
|
||||
struct.pack_into('I', vpx_encoder, 64 + 112, bitrate)
|
||||
return vpx_library.vpx_codec_enc_config_set(vpx_encoder, ctypes.cast(ctypes.addressof(vpx_encoder) + 64, ctypes.c_void_p)) == 0
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def destroy(vpx_encoder : VpxEncoder) -> None:
|
||||
vpx_library = vpx_module.create_static_library()
|
||||
|
||||
|
||||
@@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.aom_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
|
||||
library.aom_codec_get_cx_data.restype = ctypes.c_void_p
|
||||
|
||||
library.aom_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ]
|
||||
library.aom_codec_enc_config_set.restype = ctypes.c_int
|
||||
|
||||
library.aom_codec_destroy.argtypes = [ ctypes.c_void_p ]
|
||||
library.aom_codec_destroy.restype = ctypes.c_int
|
||||
|
||||
|
||||
@@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
|
||||
library.vpx_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ]
|
||||
library.vpx_codec_get_cx_data.restype = ctypes.c_void_p
|
||||
|
||||
library.vpx_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ]
|
||||
library.vpx_codec_enc_config_set.restype = ctypes.c_int
|
||||
|
||||
library.vpx_codec_destroy.argtypes = [ ctypes.c_void_p ]
|
||||
library.vpx_codec_destroy.restype = ctypes.c_int
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import ctypes
|
||||
import tempfile
|
||||
from typing import Iterator
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import ctypes
|
||||
import queue
|
||||
import struct
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -11,7 +12,7 @@ from tests.assert_helper import get_test_example_file, get_test_examples_directo
|
||||
|
||||
from facefusion import rtc, rtc_store, state_manager
|
||||
from facefusion.apis.endpoints.stream import websocket_stream
|
||||
from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop
|
||||
from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop, update_video_encoder_bitrate
|
||||
from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder
|
||||
from facefusion.common_helper import is_linux, is_macos, is_windows
|
||||
from facefusion.download import conditional_download
|
||||
@@ -118,6 +119,30 @@ def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None:
|
||||
assert not vpx_encoder.encode(video_encoder, frame_buffer, resolution, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ])
|
||||
def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None:
|
||||
vision_frame = read_video_frame(get_test_example_file('target-240p.mp4'))
|
||||
resolution = (vision_frame.shape[1], vision_frame.shape[0])
|
||||
|
||||
video_encoder = create_video_encoder(video_codec, resolution, 4000)
|
||||
|
||||
if video_codec == 'av1':
|
||||
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 4000
|
||||
|
||||
if video_codec == 'vp8':
|
||||
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 4000
|
||||
|
||||
assert update_video_encoder_bitrate(video_codec, video_encoder, 6000)
|
||||
|
||||
if video_codec == 'av1':
|
||||
assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 6000
|
||||
|
||||
if video_codec == 'vp8':
|
||||
assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 6000
|
||||
|
||||
destroy_video_encoder(video_codec, video_encoder)
|
||||
|
||||
|
||||
# TODO: refine test
|
||||
def test_receive_audio_frames() -> None:
|
||||
audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)
|
||||
|
||||
Reference in New Issue
Block a user