From 18a487347a5a22a78d86253c6601dc5071f39106 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Thu, 14 May 2026 16:11:23 +0200 Subject: [PATCH] av1 support integrated (#1112) --- facefusion/apis/stream_helper.py | 51 ++++++++++++++++++++++++--- facefusion/codecs/aom.py | 53 +++++++++++++++++++++++++---- facefusion/core.py | 3 +- facefusion/libraries/datachannel.py | 4 ++- facefusion/rtc.py | 41 +++++++++++++++++----- facefusion/rtc_store.py | 29 +++++++++++----- facefusion/types.py | 11 ++++-- tests/test_codec_aom.py | 2 +- tests/test_rtc.py | 12 +++---- 9 files changed, 164 insertions(+), 42 deletions(-) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 04a6337f..f45fc95b 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,7 +1,7 @@ import asyncio from collections import deque from collections.abc import AsyncIterator -from typing import Tuple +from typing import Tuple, cast, get_args import cv2 import numpy @@ -10,10 +10,11 @@ from starlette.websockets import WebSocket, WebSocketState from facefusion import rtc_store, session_context, session_manager, state_manager from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token +from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer, extract_aom_obus from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer from facefusion.streamer import process_vision_frame -from facefusion.types import Resolution, SessionId, VisionFrame +from facefusion.types import Resolution, SessionId, VideoCodec, VisionFrame async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]: @@ -41,8 +42,39 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr websocket_event = await websocket.receive() -# TODO: move to facefusion/vpx_encoder.py, throttle loop to avoid spinning on same frame -def run_video_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None: +def run_aom_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None: + aom_encoder = create_aom_encoder(initial_resolution, 4500, 8, 10) + current_resolution = initial_resolution + pts = 0 + + while vision_frame_deque: + vision_frame = vision_frame_deque[-1] + output_frame = process_vision_frame(vision_frame) + frame_resolution = (output_frame.shape[1], output_frame.shape[0]) + + if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]: + if aom_encoder: + destroy_aom_encoder(aom_encoder) + + current_resolution = frame_resolution + aom_encoder = create_aom_encoder(current_resolution, 4500, 8, 10) + pts = 0 + + if aom_encoder: + yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420) + frame_buffer = encode_aom_buffer(aom_encoder, yuv_frame.tobytes(), frame_resolution, pts) + + if frame_buffer: + for obu_buffer in extract_aom_obus(frame_buffer): + rtc_store.send_rtc_video(session_id, obu_buffer) + + pts += 1 + + if aom_encoder: + destroy_aom_encoder(aom_encoder) + + +def run_vp8_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None: vpx_encoder = create_vpx_encoder(initial_resolution, 4500, 8, 16) current_resolution = initial_resolution pts = 0 @@ -102,6 +134,10 @@ async def handle_video_stream(websocket : WebSocket) -> None: access_token = extract_access_token(websocket.scope) session_id = session_manager.find_session_id(access_token) session_context.set_session_id(session_id) + stream_codec : VideoCodec = 'av1' + + if websocket.query_params.get('codec') in get_args(VideoCodec): + stream_codec = cast(VideoCodec, websocket.query_params.get('codec')) await websocket.accept(subprotocol = subprotocol) @@ -127,7 +163,12 @@ async def handle_video_stream(websocket : WebSocket) -> None: rtc_store.create_rtc_stream(session_id) event_loop = asyncio.get_running_loop() - video_encode_task = event_loop.run_in_executor(None, run_video_encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval) + encode_loop = run_aom_encode_loop + + if stream_codec == 'vp8': + encode_loop = run_vp8_encode_loop + + video_encode_task = event_loop.run_in_executor(None, encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval) await websocket.send_text('ready') async for frame_type, frame_buffer in stream_frames: diff --git a/facefusion/codecs/aom.py b/facefusion/codecs/aom.py index e22b48ef..5164b4af 100644 --- a/facefusion/codecs/aom.py +++ b/facefusion/codecs/aom.py @@ -1,6 +1,6 @@ import ctypes import struct -from typing import Optional +from typing import List, Optional from facefusion.libraries import aom as aom_module from facefusion.types import AomEncoder, BitRate, Resolution @@ -10,16 +10,17 @@ def create_aom_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_ aom_library = aom_module.create_static_library() if aom_library: - aom_encoder = ctypes.create_string_buffer(1024) + aom_encoder = ctypes.create_string_buffer(128) aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_cx_algo') - config_buffer = ctypes.create_string_buffer(4096) + config_buffer = ctypes.create_string_buffer(1024) if aom_library.aom_codec_enc_config_default(ctypes.byref(aom_codec), config_buffer, 1) == 0: struct.pack_into('I', config_buffer, 4, thread_count) struct.pack_into('I', config_buffer, 12, frame_resolution[0]) struct.pack_into('I', config_buffer, 16, frame_resolution[1]) struct.pack_into('I', config_buffer, 136, bitrate) + struct.pack_into('I', config_buffer, 192, 30) if aom_library.aom_codec_enc_init_ver(aom_encoder, ctypes.byref(aom_codec), config_buffer, 0, 25) == 0: aom_library.aom_codec_control(aom_encoder, 13, ctypes.c_int(cpu_count)) @@ -37,15 +38,12 @@ def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_reso output_buffer = b'' if aom_library: - temp_buffer = ctypes.create_string_buffer(512) + temp_buffer = ctypes.create_string_buffer(256) encode_buffer = ctypes.create_string_buffer(input_buffer) if aom_library.aom_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and aom_library.aom_codec_encode(aom_encoder, temp_buffer, frame_index, 1, 0, 1) == 0: output_buffer = collect_aom_buffer(aom_encoder) - if output_buffer.startswith(bytes([ 0x12, 0x00 ])): - output_buffer = output_buffer[2:] - return output_buffer @@ -67,6 +65,47 @@ def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes: return output_buffer +# TODO: try to eliminate this +def extract_aom_obus(frame_buffer : bytes) -> List[bytes]: + obu_list : List[bytes] = [] + offset = 0 + + while offset < len(frame_buffer): + header_offset = offset + header = frame_buffer[offset] + obu_type = (header >> 3) & 0x0F + has_extension = (header >> 2) & 0x01 + has_size = (header >> 1) & 0x01 + offset += 1 + has_extension + + obu_size = 0 + + if has_size: + shift = 0 + + while offset < len(frame_buffer): + leb_byte = frame_buffer[offset] + offset += 1 + obu_size |= (leb_byte & 0x7F) << shift + shift += 7 + + if not (leb_byte & 0x80): + break + + payload_offset = offset + offset += obu_size + + if obu_type != 2: + clean_header = bytes([ header & 0xFD ]) + + if has_extension: + clean_header += frame_buffer[header_offset + 1:header_offset + 2] + + obu_list.append(clean_header + frame_buffer[payload_offset:payload_offset + obu_size]) + + return obu_list + + def destroy_aom_encoder(aom_encoder : AomEncoder) -> None: aom_library = aom_module.create_static_library() diff --git a/facefusion/core.py b/facefusion/core.py index a71c29a5..bb26e8d6 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -16,7 +16,7 @@ from facefusion.filesystem import get_file_extension, has_audio, has_image, has_ from facefusion.filesystem import get_file_name, resolve_file_paths, resolve_file_pattern from facefusion.jobs import job_helper, job_manager, job_runner from facefusion.jobs.job_list import compose_job_list -from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module +from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, opus as opus_module, vpx as vpx_module from facefusion.processors.core import get_processors_modules from facefusion.program import create_program from facefusion.program_helper import validate_args @@ -105,6 +105,7 @@ def pre_check() -> bool: def common_pre_check() -> bool: common_modules =\ [ + aom_module, datachannel_module, content_analyser, face_classifier, diff --git a/facefusion/libraries/datachannel.py b/facefusion/libraries/datachannel.py index bc8dbeb5..9db8dda3 100644 --- a/facefusion/libraries/datachannel.py +++ b/facefusion/libraries/datachannel.py @@ -185,6 +185,7 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL: library.rtcSendMessage.argtypes = [ ctypes.c_int, ctypes.c_void_p, ctypes.c_int ] library.rtcSendMessage.restype = ctypes.c_int + library.rtcSetAV1Packetizer.restype = ctypes.c_int library.rtcSetVP8Packetizer.restype = ctypes.c_int library.rtcChainRtcpSrReporter.argtypes = [ ctypes.c_int ] @@ -256,6 +257,7 @@ def define_rtc_packetizer_init() -> ctypes.Structure: ('clockRate', ctypes.c_uint32), ('sequenceNumber', ctypes.c_uint16), ('timestamp', ctypes.c_uint32), - ('maxFragmentSize', ctypes.c_uint16) + ('maxFragmentSize', ctypes.c_uint16), + ('obuPacketization', ctypes.c_int) ] })() diff --git a/facefusion/rtc.py b/facefusion/rtc.py index f401f324..edf74b02 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -4,7 +4,7 @@ import time from typing import Dict, List, Optional from facefusion.libraries import datachannel as datachannel_module -from facefusion.types import MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer +from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec # TODO: reduce to only used params @@ -62,18 +62,27 @@ def build_media_description(media_type : str, payload_type : int, rtp_codec : st def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]: payload_types : Dict[str, int] = {} + # TODO: consider having a codec helper to resolve these for line in sdp_offer.splitlines(): - if line.startswith('a=rtpmap:') and 'VP8/90000' in line: + if line.startswith('a=rtpmap:') and 'AV1/90000' in line and not payload_types.get('av1'): + payload_types['av1'] = int(line.split(':')[1].split(' ')[0]) + if line.startswith('a=rtpmap:') and 'VP8/90000' in line and not payload_types.get('vp8'): payload_types['vp8'] = int(line.split(':')[1].split(' ')[0]) - if line.startswith('a=rtpmap:') and 'opus/48000/2' in line: + if line.startswith('a=rtpmap:') and 'opus/48000/2' in line and not payload_types.get('opus'): payload_types['opus'] = int(line.split(':')[1].split(' ')[0]) return payload_types -def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDirection, payload_type : int) -> RtcAudioTrack: +def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> RtcAudioTrack: datachannel_library = datachannel_module.create_static_library() - media_description = build_media_description('audio', payload_type, 'opus/48000/2', media_direction, 1) + + # TODO: Fix me via resolve method + rtp_codec = 'opus/48000/2' + if audio_codec == 'opus': + rtp_codec = 'opus/48000/2' + + media_description = build_media_description('audio', payload_type, rtp_codec, media_direction, 1) audio_track = datachannel_library.rtcAddTrack(peer_connection, media_description) @@ -83,15 +92,25 @@ def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDir audio_packetizer.payloadType = payload_type audio_packetizer.clockRate = 48000 - datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer)) + if audio_codec == 'opus': + datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer)) + datachannel_library.rtcChainRtcpSrReporter(audio_track) return audio_track -def add_video_track(peer_connection : PeerConnection, media_direction : MediaDirection, payload_type : int) -> RtcVideoTrack: +def add_video_track(peer_connection : PeerConnection, media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> RtcVideoTrack: datachannel_library = datachannel_module.create_static_library() - media_description = build_media_description('video', payload_type, 'VP8/90000', media_direction, 0) + + #TODO: Fix me via resolve method + rtp_codec = 'AV1/90000' + if video_codec == 'av1': + rtp_codec = 'AV1/90000' + if video_codec == 'vp8': + rtp_codec = 'VP8/90000' + + media_description = build_media_description('video', payload_type, rtp_codec, media_direction, 0) video_track = datachannel_library.rtcAddTrack(peer_connection, media_description) @@ -102,7 +121,11 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir video_packetizer.clockRate = 90000 video_packetizer.maxFragmentSize = 1200 - datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) + if video_codec == 'av1': + datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer)) + if video_codec == 'vp8': + datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer)) + datachannel_library.rtcChainRtcpSrReporter(video_track) datachannel_library.rtcChainRtcpNackResponder(video_track, 512) diff --git a/facefusion/rtc_store.py b/facefusion/rtc_store.py index 07b9ba67..9725ad45 100644 --- a/facefusion/rtc_store.py +++ b/facefusion/rtc_store.py @@ -1,21 +1,23 @@ from typing import List, Optional from facefusion import rtc -from facefusion.types import PeerConnection, RtcAudioTrack, RtcPeer, RtcStreamStore, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId +from facefusion.types import AudioCodec, PeerConnection, RtcAudioTrack, RtcPeer, RtcStreamStore, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec -RTC_STREAMS : RtcStreamStore = {} + +# TODO: aint this a peer store? +RTC_STREAM_STORE : RtcStreamStore = {} def get_rtc_stream(session_id : SessionId) -> Optional[List[RtcPeer]]: - return RTC_STREAMS.get(session_id) + return RTC_STREAM_STORE.get(session_id) def create_rtc_stream(session_id : SessionId) -> None: - RTC_STREAMS[session_id] = [] + RTC_STREAM_STORE[session_id] = [] def destroy_rtc_stream(session_id : SessionId) -> None: - rtc_peers = RTC_STREAMS.pop(session_id, None) + rtc_peers = RTC_STREAM_STORE.pop(session_id, None) if rtc_peers: rtc.delete_peers(rtc_peers) @@ -23,11 +25,20 @@ def destroy_rtc_stream(session_id : SessionId) -> None: # TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]: - if session_id in RTC_STREAMS: + if session_id in RTC_STREAM_STORE: payload_types = rtc.parse_sdp_payload_types(sdp_offer) peer_connection : PeerConnection = rtc.create_peer_connection() - audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', payload_types.get('opus', 111)) - video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', payload_types.get('vp8', 96)) + audio_codec : AudioCodec = 'opus' + audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, payload_types.get(audio_codec, 111)) + + #TODO: Fix me via resolve method + video_codec : VideoCodec = 'av1' + if payload_types.get('av1'): + video_codec = 'av1' + if payload_types.get('vp8'): + video_codec = 'vp8' + + video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', video_codec, payload_types.get(video_codec, 96)) local_sdp = rtc.negotiate_sdp(peer_connection, sdp_offer) if local_sdp: @@ -37,7 +48,7 @@ def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[Sdp 'video_track': video_track, 'audio_track': audio_track } - RTC_STREAMS[session_id].append(rtc_peer) + RTC_STREAM_STORE[session_id].append(rtc_peer) return local_sdp diff --git a/facefusion/types.py b/facefusion/types.py index e41da57f..7b2e17b3 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -90,6 +90,9 @@ MelFilterBank : TypeAlias = NDArray[Any] Voice : TypeAlias = NDArray[Any] VoiceChunk : TypeAlias = NDArray[Any] +AudioCodec : TypeAlias = Literal['opus'] +VideoCodec : TypeAlias = Literal['av1', 'vp8'] + AomEncoder : TypeAlias = ctypes.Array[ctypes.c_char] OpusEncoder : TypeAlias = ctypes.c_void_p VpxEncoder : TypeAlias = ctypes.Array[ctypes.c_char] @@ -267,13 +270,15 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet', WebcamMode = Literal['inline', 'udp', 'v4l2'] StreamMode = Literal['udp', 'v4l2'] -RtcVideoTrack : TypeAlias = int -RtcAudioTrack : TypeAlias = int + PeerConnection : TypeAlias = int SdpOffer : TypeAlias = str SdpAnswer : TypeAlias = str MediaDirection : TypeAlias = Literal['sendonly', 'recvonly', 'sendrecv', 'inactive'] +RtcVideoTrack : TypeAlias = int +RtcAudioTrack : TypeAlias = int + RtcPeer = TypedDict('RtcPeer', { 'peer_connection': PeerConnection, @@ -281,7 +286,7 @@ RtcPeer = TypedDict('RtcPeer', 'audio_track': RtcAudioTrack, }) -RtcStreamStore : TypeAlias = Dict[str, List[RtcPeer]] +RtcStreamStore : TypeAlias = Dict[SessionId, List[RtcPeer]] ModelOptions : TypeAlias = Dict[str, Any] ModelSet : TypeAlias = Dict[str, ModelOptions] diff --git a/tests/test_codec_aom.py b/tests/test_codec_aom.py index 205ab327..7548d368 100644 --- a/tests/test_codec_aom.py +++ b/tests/test_codec_aom.py @@ -34,7 +34,7 @@ def test_encode_aom_buffer() -> None: aom_encoder = create_aom_encoder(video_resolution, 1000, 1, 0) if is_linux() or is_windows(): - assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '4b621fb8' + assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31' if is_macos(): assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '64c12977' diff --git a/tests/test_rtc.py b/tests/test_rtc.py index bbc1513b..a0d221d0 100644 --- a/tests/test_rtc.py +++ b/tests/test_rtc.py @@ -33,7 +33,7 @@ def test_create_peer_connection() -> None: def test_add_audio_track() -> None: peer_connection = rtc.create_peer_connection() - assert rtc.add_audio_track(peer_connection, 'sendonly', 111) > 0 + assert rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) > 0 datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) @@ -41,7 +41,7 @@ def test_add_audio_track() -> None: def test_add_video_track() -> None: peer_connection = rtc.create_peer_connection() - assert rtc.add_video_track(peer_connection, 'sendonly', 96) > 0 + assert rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) > 0 datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection) @@ -50,13 +50,13 @@ def test_negotiate_sdp() -> None: datachannel_library = datachannel_module.create_static_library() sender_connection = rtc.create_peer_connection() - rtc.add_video_track(sender_connection, 'sendonly', 96) - rtc.add_audio_track(sender_connection, 'sendonly', 111) + rtc.add_video_track(sender_connection, 'sendonly', 'vp8', 96) + rtc.add_audio_track(sender_connection, 'sendonly', 'opus', 111) sdp_offer = rtc.create_sdp(sender_connection) receiver_connection = rtc.create_peer_connection() - rtc.add_video_track(receiver_connection, 'recvonly', 96) - rtc.add_audio_track(receiver_connection, 'recvonly', 111) + rtc.add_video_track(receiver_connection, 'recvonly', 'vp8', 96) + rtc.add_audio_track(receiver_connection, 'recvonly', 'opus', 111) sdp_answer = rtc.negotiate_sdp(receiver_connection, sdp_offer) assert sdp_answer