diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 797c43c0..70a66be5 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -152,11 +152,13 @@ def run_peer_loop(session_id : SessionId, rtc_peer : RtcPeer) -> None: video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate) frame_index = 0 - if peer_bitrate and peer_bitrate != temp_bitrate: # TODO avoid != in condition - destroy_video_encoder(video_codec, video_encoder) + if peer_bitrate and peer_bitrate - temp_bitrate: temp_bitrate = peer_bitrate - video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate) - frame_index = 0 + + if not update_video_encoder_bitrate(video_codec, video_encoder, temp_bitrate): + destroy_video_encoder(video_codec, video_encoder) + video_encoder = create_video_encoder(video_codec, temp_resolution, temp_bitrate) + frame_index = 0 output_video_buffer = encode_video_frame(video_codec, video_encoder, output_vision_buffer, temp_resolution, frame_index) @@ -292,6 +294,16 @@ def destroy_video_decoder(video_codec : VideoCodec, video_decoder : VpxDecoder | vpx_decoder.destroy(video_decoder) +def update_video_encoder_bitrate(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder, bitrate : BitRate) -> bool: + if video_codec == 'av1': + return aom_encoder.update_bitrate(video_encoder, bitrate) + + if video_codec == 'vp8': + return vpx_encoder.update_bitrate(video_encoder, bitrate) + + return False + + def destroy_video_encoder(video_codec : VideoCodec, video_encoder : VpxEncoder | AomEncoder) -> None: if video_codec == 'av1': aom_encoder.destroy(video_encoder) diff --git a/facefusion/codecs/aom_encoder.py b/facefusion/codecs/aom_encoder.py index d3f7eefe..cdc6cfd7 100644 --- a/facefusion/codecs/aom_encoder.py +++ b/facefusion/codecs/aom_encoder.py @@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, aom_library = aom_module.create_static_library() if aom_library: - aom_encoder = ctypes.create_string_buffer(128) + aom_encoder = ctypes.create_string_buffer(1152) aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_cx_algo') config_buffer = ctypes.create_string_buffer(1024) @@ -28,6 +28,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, aom_library.aom_codec_control(aom_encoder, 106, ctypes.c_int(1)) aom_library.aom_codec_control(aom_encoder, 122, ctypes.c_int(0)) aom_library.aom_codec_control(aom_encoder, 123, ctypes.c_int(0)) + ctypes.memmove(ctypes.addressof(aom_encoder) + 128, config_buffer, 1024) return aom_encoder return None @@ -65,6 +66,16 @@ def collect(aom_encoder : AomEncoder) -> bytes: return bytes().join(output_parts) +def update_bitrate(aom_encoder : AomEncoder, bitrate : BitRate) -> bool: + aom_library = aom_module.create_static_library() + + if aom_library: + struct.pack_into('I', aom_encoder, 128 + 136, bitrate) + return aom_library.aom_codec_enc_config_set(aom_encoder, ctypes.cast(ctypes.addressof(aom_encoder) + 128, ctypes.c_void_p)) == 0 + + return False + + def destroy(aom_encoder : AomEncoder) -> None: aom_library = aom_module.create_static_library() diff --git a/facefusion/codecs/vpx_encoder.py b/facefusion/codecs/vpx_encoder.py index c03cccf5..45977da0 100644 --- a/facefusion/codecs/vpx_encoder.py +++ b/facefusion/codecs/vpx_encoder.py @@ -10,7 +10,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, vpx_library = vpx_module.create_static_library() if vpx_library: - vpx_encoder = ctypes.create_string_buffer(64) + vpx_encoder = ctypes.create_string_buffer(640) vp8_codec = ctypes.c_void_p.in_dll(vpx_library, 'vpx_codec_vp8_cx_algo') config_buffer = ctypes.create_string_buffer(512) @@ -32,6 +32,7 @@ def create(frame_resolution : Resolution, bitrate : BitRate, thread_count : int, vpx_library.vpx_codec_control_(vpx_encoder, 13, ctypes.c_int(cpu_count)) vpx_library.vpx_codec_control_(vpx_encoder, 12, ctypes.c_int(3)) vpx_library.vpx_codec_control_(vpx_encoder, 27, ctypes.c_int(10)) + ctypes.memmove(ctypes.addressof(vpx_encoder) + 64, config_buffer, 512) return vpx_encoder return None @@ -69,6 +70,16 @@ def collect(vpx_encoder : VpxEncoder) -> bytes: return bytes().join(output_parts) +def update_bitrate(vpx_encoder : VpxEncoder, bitrate : BitRate) -> bool: + vpx_library = vpx_module.create_static_library() + + if vpx_library: + struct.pack_into('I', vpx_encoder, 64 + 112, bitrate) + return vpx_library.vpx_codec_enc_config_set(vpx_encoder, ctypes.cast(ctypes.addressof(vpx_encoder) + 64, ctypes.c_void_p)) == 0 + + return False + + def destroy(vpx_encoder : VpxEncoder) -> None: vpx_library = vpx_module.create_static_library() diff --git a/facefusion/libraries/aom.py b/facefusion/libraries/aom.py index e176b88e..663ee7c2 100644 --- a/facefusion/libraries/aom.py +++ b/facefusion/libraries/aom.py @@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.aom_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ] library.aom_codec_get_cx_data.restype = ctypes.c_void_p + library.aom_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ] + library.aom_codec_enc_config_set.restype = ctypes.c_int + library.aom_codec_destroy.argtypes = [ ctypes.c_void_p ] library.aom_codec_destroy.restype = ctypes.c_int diff --git a/facefusion/libraries/vpx.py b/facefusion/libraries/vpx.py index c5a91fa4..b128643d 100644 --- a/facefusion/libraries/vpx.py +++ b/facefusion/libraries/vpx.py @@ -110,6 +110,9 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.vpx_codec_get_cx_data.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p) ] library.vpx_codec_get_cx_data.restype = ctypes.c_void_p + library.vpx_codec_enc_config_set.argtypes = [ ctypes.c_void_p, ctypes.c_void_p ] + library.vpx_codec_enc_config_set.restype = ctypes.c_int + library.vpx_codec_destroy.argtypes = [ ctypes.c_void_p ] library.vpx_codec_destroy.restype = ctypes.c_int diff --git a/tests/test_api_stream.py b/tests/test_api_stream.py index 9420dbf2..ac4784d1 100644 --- a/tests/test_api_stream.py +++ b/tests/test_api_stream.py @@ -1,4 +1,3 @@ -import ctypes import tempfile from typing import Iterator from unittest.mock import patch diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py index 31ab4ce9..14e01323 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_stream_helper.py @@ -1,5 +1,6 @@ import ctypes import queue +import struct import threading from unittest.mock import AsyncMock, MagicMock, patch @@ -11,7 +12,7 @@ from tests.assert_helper import get_test_example_file, get_test_examples_directo from facefusion import rtc, rtc_store, state_manager from facefusion.apis.endpoints.stream import websocket_stream -from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop +from facefusion.apis.stream_helper import create_video_encoder, decode_video_frame, destroy_video_encoder, process_image, process_video, receive_audio_frames, receive_video_frames, receive_vision_frames, run_peer_loop, update_video_encoder_bitrate from facefusion.codecs import aom_decoder, aom_encoder, vpx_decoder, vpx_encoder from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.download import conditional_download @@ -118,6 +119,30 @@ def test_create_and_destroy_video_encoder(video_codec : VideoCodec) -> None: assert not vpx_encoder.encode(video_encoder, frame_buffer, resolution, 1) +@pytest.mark.parametrize('video_codec', [ 'av1', 'vp8' ]) +def test_update_video_encoder_bitrate(video_codec : VideoCodec) -> None: + vision_frame = read_video_frame(get_test_example_file('target-240p.mp4')) + resolution = (vision_frame.shape[1], vision_frame.shape[0]) + + video_encoder = create_video_encoder(video_codec, resolution, 4000) + + if video_codec == 'av1': + assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 4000 + + if video_codec == 'vp8': + assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 4000 + + assert update_video_encoder_bitrate(video_codec, video_encoder, 6000) + + if video_codec == 'av1': + assert struct.unpack_from('I', video_encoder, 128 + 136)[0] == 6000 + + if video_codec == 'vp8': + assert struct.unpack_from('I', video_encoder, 64 + 112)[0] == 6000 + + destroy_video_encoder(video_codec, video_encoder) + + # TODO: refine test def test_receive_audio_frames() -> None: audio_frame = numpy.zeros(960 * 2, dtype = numpy.float32)