av1 support integrated (#1112)

This commit is contained in:
Henry Ruhs
2026-05-14 16:11:23 +02:00
committed by GitHub
parent b607e4a99e
commit 18a487347a
9 changed files with 164 additions and 42 deletions
+46 -5
View File
@@ -1,7 +1,7 @@
import asyncio
from collections import deque
from collections.abc import AsyncIterator
from typing import Tuple
from typing import Tuple, cast, get_args
import cv2
import numpy
@@ -10,10 +10,11 @@ from starlette.websockets import WebSocket, WebSocketState
from facefusion import rtc_store, session_context, session_manager, state_manager
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion.codecs.aom import create_aom_encoder, destroy_aom_encoder, encode_aom_buffer, extract_aom_obus
from facefusion.codecs.opus import create_opus_encoder, destroy_opus_encoder, encode_opus_buffer
from facefusion.codecs.vpx import create_vpx_encoder, destroy_vpx_encoder, encode_vpx_buffer
from facefusion.streamer import process_vision_frame
from facefusion.types import Resolution, SessionId, VisionFrame
from facefusion.types import Resolution, SessionId, VideoCodec, VisionFrame
async def receive_stream_frames(websocket : WebSocket) -> AsyncIterator[Tuple[int, bytes]]:
@@ -41,8 +42,39 @@ async def receive_vision_frames(websocket : WebSocket) -> AsyncIterator[VisionFr
websocket_event = await websocket.receive()
# TODO: move to facefusion/vpx_encoder.py, throttle loop to avoid spinning on same frame
def run_video_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
def run_aom_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
# AV1 encode loop: peeks the newest frame of vision_frame_deque, runs it through
# process_vision_frame, encodes the result with encode_aom_buffer and forwards every
# OBU returned by extract_aom_obus via rtc_store.send_rtc_video for session_id.
# NOTE(review): indentation is not visible in this view, so the nesting implied by the
# comments below is inferred from statement order — confirm against the real file.
# NOTE(review): keyframe_interval is accepted but never referenced in this body.
aom_encoder = create_aom_encoder(initial_resolution, 4500, 8, 10)
current_resolution = initial_resolution
pts = 0
# The deque is only peeked at ([-1]) and never popped here, so the loop ends only once
# something else empties it — presumably the websocket handler; verify against caller.
while vision_frame_deque:
vision_frame = vision_frame_deque[-1]
output_frame = process_vision_frame(vision_frame)
# numpy shape is (height, width, ...), so index 1 / index 0 gives (width, height).
frame_resolution = (output_frame.shape[1], output_frame.shape[0])
# The processed frame size differs from the size the encoder was created with:
# destroy the current encoder, create one for the new size and restart pts at 0.
if frame_resolution[0] != current_resolution[0] or frame_resolution[1] != current_resolution[1]:
if aom_encoder:
destroy_aom_encoder(aom_encoder)
current_resolution = frame_resolution
aom_encoder = create_aom_encoder(current_resolution, 4500, 8, 10)
pts = 0
# create_aom_encoder may yield a falsy value, in which case encoding is skipped.
if aom_encoder:
# The encoder is fed the raw bytes of the frame converted from BGR to planar I420.
yuv_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2YUV_I420)
frame_buffer = encode_aom_buffer(aom_encoder, yuv_frame.tobytes(), frame_resolution, pts)
# An empty buffer is not forwarded; otherwise each extracted OBU is sent separately.
if frame_buffer:
for obu_buffer in extract_aom_obus(frame_buffer):
rtc_store.send_rtc_video(session_id, obu_buffer)
# pts is passed to encode_aom_buffer as the frame index and advanced here.
pts += 1
# Release the encoder that is still held when the loop is left.
if aom_encoder:
destroy_aom_encoder(aom_encoder)
def run_vp8_encode_loop(vision_frame_deque : deque[VisionFrame], session_id : SessionId, initial_resolution : Resolution, keyframe_interval : int) -> None:
vpx_encoder = create_vpx_encoder(initial_resolution, 4500, 8, 16)
current_resolution = initial_resolution
pts = 0
@@ -102,6 +134,10 @@ async def handle_video_stream(websocket : WebSocket) -> None:
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
stream_codec : VideoCodec = 'av1'
if websocket.query_params.get('codec') in get_args(VideoCodec):
stream_codec = cast(VideoCodec, websocket.query_params.get('codec'))
await websocket.accept(subprotocol = subprotocol)
@@ -127,7 +163,12 @@ async def handle_video_stream(websocket : WebSocket) -> None:
rtc_store.create_rtc_stream(session_id)
event_loop = asyncio.get_running_loop()
video_encode_task = event_loop.run_in_executor(None, run_video_encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval)
encode_loop = run_aom_encode_loop
if stream_codec == 'vp8':
encode_loop = run_vp8_encode_loop
video_encode_task = event_loop.run_in_executor(None, encode_loop, vision_frame_deque, session_id, resolution, keyframe_interval)
await websocket.send_text('ready')
async for frame_type, frame_buffer in stream_frames:
+46 -7
View File
@@ -1,6 +1,6 @@
import ctypes
import struct
from typing import Optional
from typing import List, Optional
from facefusion.libraries import aom as aom_module
from facefusion.types import AomEncoder, BitRate, Resolution
@@ -10,16 +10,17 @@ def create_aom_encoder(frame_resolution : Resolution, bitrate : BitRate, thread_
aom_library = aom_module.create_static_library()
if aom_library:
aom_encoder = ctypes.create_string_buffer(1024)
aom_encoder = ctypes.create_string_buffer(128)
aom_codec = ctypes.c_void_p.in_dll(aom_library, 'aom_codec_av1_cx_algo')
config_buffer = ctypes.create_string_buffer(4096)
config_buffer = ctypes.create_string_buffer(1024)
if aom_library.aom_codec_enc_config_default(ctypes.byref(aom_codec), config_buffer, 1) == 0:
struct.pack_into('I', config_buffer, 4, thread_count)
struct.pack_into('I', config_buffer, 12, frame_resolution[0])
struct.pack_into('I', config_buffer, 16, frame_resolution[1])
struct.pack_into('I', config_buffer, 136, bitrate)
struct.pack_into('I', config_buffer, 192, 30)
if aom_library.aom_codec_enc_init_ver(aom_encoder, ctypes.byref(aom_codec), config_buffer, 0, 25) == 0:
aom_library.aom_codec_control(aom_encoder, 13, ctypes.c_int(cpu_count))
@@ -37,15 +38,12 @@ def encode_aom_buffer(aom_encoder : AomEncoder, input_buffer : bytes, frame_reso
output_buffer = b''
if aom_library:
temp_buffer = ctypes.create_string_buffer(512)
temp_buffer = ctypes.create_string_buffer(256)
encode_buffer = ctypes.create_string_buffer(input_buffer)
if aom_library.aom_img_wrap(temp_buffer, 0x102, frame_resolution[0], frame_resolution[1], 1, encode_buffer) and aom_library.aom_codec_encode(aom_encoder, temp_buffer, frame_index, 1, 0, 1) == 0:
output_buffer = collect_aom_buffer(aom_encoder)
if output_buffer.startswith(bytes([ 0x12, 0x00 ])):
output_buffer = output_buffer[2:]
return output_buffer
@@ -67,6 +65,47 @@ def collect_aom_buffer(aom_encoder : AomEncoder) -> bytes:
return output_buffer
# TODO: try to eliminate this
def extract_aom_obus(frame_buffer : bytes) -> List[bytes]:
    """
    Split an encoded AV1 buffer into individual OBUs without their size field.

    Every returned item consists of the OBU header byte with the has-size bit
    (0x02) cleared, the extension byte when the extension flag (0x04) is set,
    and the OBU payload. OBUs of type 2 (temporal delimiter in the AV1
    bitstream specification) are dropped.

    An OBU whose header has no size field is treated as spanning the rest of
    the buffer, instead of being emitted with an empty payload and having its
    payload bytes parsed as further OBU headers.

    :param frame_buffer: concatenated OBUs as produced by the encoder
    :return: list of OBUs in stream order, empty for an empty buffer
    """
    obu_list : List[bytes] = []
    buffer_size = len(frame_buffer)
    offset = 0

    while offset < buffer_size:
        header_offset = offset
        header = frame_buffer[offset]
        obu_type = (header >> 3) & 0x0F
        has_extension = (header >> 2) & 0x01
        has_size = (header >> 1) & 0x01
        # skip the header byte and the optional extension byte
        offset += 1 + has_extension

        if has_size:
            # payload length is stored as little-endian base 128 (leb128)
            obu_size = 0
            shift = 0

            while offset < buffer_size:
                leb_byte = frame_buffer[offset]
                offset += 1
                obu_size |= (leb_byte & 0x7F) << shift
                shift += 7

                if not (leb_byte & 0x80):
                    break
        else:
            # without a size field the OBU runs to the end of the buffer
            obu_size = max(buffer_size - offset, 0)

        payload_offset = offset
        offset += obu_size

        if obu_type != 2:
            # clear the has-size bit as the size field is not copied over
            clean_header = bytes([ header & 0xFD ])

            if has_extension:
                clean_header += frame_buffer[header_offset + 1:header_offset + 2]
            obu_list.append(clean_header + frame_buffer[payload_offset:payload_offset + obu_size])

    return obu_list
def destroy_aom_encoder(aom_encoder : AomEncoder) -> None:
aom_library = aom_module.create_static_library()
+2 -1
View File
@@ -16,7 +16,7 @@ from facefusion.filesystem import get_file_extension, has_audio, has_image, has_
from facefusion.filesystem import get_file_name, resolve_file_paths, resolve_file_pattern
from facefusion.jobs import job_helper, job_manager, job_runner
from facefusion.jobs.job_list import compose_job_list
from facefusion.libraries import datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
from facefusion.libraries import aom as aom_module, datachannel as datachannel_module, opus as opus_module, vpx as vpx_module
from facefusion.processors.core import get_processors_modules
from facefusion.program import create_program
from facefusion.program_helper import validate_args
@@ -105,6 +105,7 @@ def pre_check() -> bool:
def common_pre_check() -> bool:
common_modules =\
[
aom_module,
datachannel_module,
content_analyser,
face_classifier,
+3 -1
View File
@@ -185,6 +185,7 @@ def init_ctypes(library : ctypes.CDLL) -> ctypes.CDLL:
library.rtcSendMessage.argtypes = [ ctypes.c_int, ctypes.c_void_p, ctypes.c_int ]
library.rtcSendMessage.restype = ctypes.c_int
library.rtcSetAV1Packetizer.restype = ctypes.c_int
library.rtcSetVP8Packetizer.restype = ctypes.c_int
library.rtcChainRtcpSrReporter.argtypes = [ ctypes.c_int ]
@@ -256,6 +257,7 @@ def define_rtc_packetizer_init() -> ctypes.Structure:
('clockRate', ctypes.c_uint32),
('sequenceNumber', ctypes.c_uint16),
('timestamp', ctypes.c_uint32),
('maxFragmentSize', ctypes.c_uint16)
('maxFragmentSize', ctypes.c_uint16),
('obuPacketization', ctypes.c_int)
]
})()
+32 -9
View File
@@ -4,7 +4,7 @@ import time
from typing import Dict, List, Optional
from facefusion.libraries import datachannel as datachannel_module
from facefusion.types import MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer
from facefusion.types import AudioCodec, MediaDirection, PeerConnection, RtcAudioTrack, RtcPeer, RtcVideoTrack, SdpAnswer, SdpOffer, VideoCodec
# TODO: reduce to only used params
@@ -62,18 +62,27 @@ def build_media_description(media_type : str, payload_type : int, rtp_codec : st
def parse_sdp_payload_types(sdp_offer : SdpOffer) -> Dict[str, int]:
payload_types : Dict[str, int] = {}
# TODO: consider having a codec helper to resolve these
for line in sdp_offer.splitlines():
if line.startswith('a=rtpmap:') and 'VP8/90000' in line:
if line.startswith('a=rtpmap:') and 'AV1/90000' in line and not payload_types.get('av1'):
payload_types['av1'] = int(line.split(':')[1].split(' ')[0])
if line.startswith('a=rtpmap:') and 'VP8/90000' in line and not payload_types.get('vp8'):
payload_types['vp8'] = int(line.split(':')[1].split(' ')[0])
if line.startswith('a=rtpmap:') and 'opus/48000/2' in line:
if line.startswith('a=rtpmap:') and 'opus/48000/2' in line and not payload_types.get('opus'):
payload_types['opus'] = int(line.split(':')[1].split(' ')[0])
return payload_types
def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDirection, payload_type : int) -> RtcAudioTrack:
def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDirection, audio_codec : AudioCodec, payload_type : int) -> RtcAudioTrack:
datachannel_library = datachannel_module.create_static_library()
media_description = build_media_description('audio', payload_type, 'opus/48000/2', media_direction, 1)
# TODO: Fix me via resolve method
rtp_codec = 'opus/48000/2'
if audio_codec == 'opus':
rtp_codec = 'opus/48000/2'
media_description = build_media_description('audio', payload_type, rtp_codec, media_direction, 1)
audio_track = datachannel_library.rtcAddTrack(peer_connection, media_description)
@@ -83,15 +92,25 @@ def add_audio_track(peer_connection : PeerConnection, media_direction : MediaDir
audio_packetizer.payloadType = payload_type
audio_packetizer.clockRate = 48000
datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
if audio_codec == 'opus':
datachannel_library.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
datachannel_library.rtcChainRtcpSrReporter(audio_track)
return audio_track
def add_video_track(peer_connection : PeerConnection, media_direction : MediaDirection, payload_type : int) -> RtcVideoTrack:
def add_video_track(peer_connection : PeerConnection, media_direction : MediaDirection, video_codec : VideoCodec, payload_type : int) -> RtcVideoTrack:
datachannel_library = datachannel_module.create_static_library()
media_description = build_media_description('video', payload_type, 'VP8/90000', media_direction, 0)
#TODO: Fix me via resolve method
rtp_codec = 'AV1/90000'
if video_codec == 'av1':
rtp_codec = 'AV1/90000'
if video_codec == 'vp8':
rtp_codec = 'VP8/90000'
media_description = build_media_description('video', payload_type, rtp_codec, media_direction, 0)
video_track = datachannel_library.rtcAddTrack(peer_connection, media_description)
@@ -102,7 +121,11 @@ def add_video_track(peer_connection : PeerConnection, media_direction : MediaDir
video_packetizer.clockRate = 90000
video_packetizer.maxFragmentSize = 1200
datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
if video_codec == 'av1':
datachannel_library.rtcSetAV1Packetizer(video_track, ctypes.byref(video_packetizer))
if video_codec == 'vp8':
datachannel_library.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
datachannel_library.rtcChainRtcpSrReporter(video_track)
datachannel_library.rtcChainRtcpNackResponder(video_track, 512)
+20 -9
View File
@@ -1,21 +1,23 @@
from typing import List, Optional
from facefusion import rtc
from facefusion.types import PeerConnection, RtcAudioTrack, RtcPeer, RtcStreamStore, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId
from facefusion.types import AudioCodec, PeerConnection, RtcAudioTrack, RtcPeer, RtcStreamStore, RtcVideoTrack, SdpAnswer, SdpOffer, SessionId, VideoCodec
RTC_STREAMS : RtcStreamStore = {}
# TODO: isn't this a peer store?
RTC_STREAM_STORE : RtcStreamStore = {}
def get_rtc_stream(session_id : SessionId) -> Optional[List[RtcPeer]]:
return RTC_STREAMS.get(session_id)
return RTC_STREAM_STORE.get(session_id)
def create_rtc_stream(session_id : SessionId) -> None:
RTC_STREAMS[session_id] = []
RTC_STREAM_STORE[session_id] = []
def destroy_rtc_stream(session_id : SessionId) -> None:
rtc_peers = RTC_STREAMS.pop(session_id, None)
rtc_peers = RTC_STREAM_STORE.pop(session_id, None)
if rtc_peers:
rtc.delete_peers(rtc_peers)
@@ -23,11 +25,20 @@ def destroy_rtc_stream(session_id : SessionId) -> None:
# TODO: clean up peer connection on failed sdp negotiation, wrap in run_in_executor to avoid blocking async event loop
def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[SdpAnswer]:
if session_id in RTC_STREAMS:
if session_id in RTC_STREAM_STORE:
payload_types = rtc.parse_sdp_payload_types(sdp_offer)
peer_connection : PeerConnection = rtc.create_peer_connection()
audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', payload_types.get('opus', 111))
video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', payload_types.get('vp8', 96))
audio_codec : AudioCodec = 'opus'
audio_track : RtcAudioTrack = rtc.add_audio_track(peer_connection, 'sendonly', audio_codec, payload_types.get(audio_codec, 111))
#TODO: Fix me via resolve method
video_codec : VideoCodec = 'av1'
if payload_types.get('av1'):
video_codec = 'av1'
if payload_types.get('vp8'):
video_codec = 'vp8'
video_track : RtcVideoTrack = rtc.add_video_track(peer_connection, 'sendonly', video_codec, payload_types.get(video_codec, 96))
local_sdp = rtc.negotiate_sdp(peer_connection, sdp_offer)
if local_sdp:
@@ -37,7 +48,7 @@ def add_rtc_viewer(session_id : SessionId, sdp_offer : SdpOffer) -> Optional[Sdp
'video_track': video_track,
'audio_track': audio_track
}
RTC_STREAMS[session_id].append(rtc_peer)
RTC_STREAM_STORE[session_id].append(rtc_peer)
return local_sdp
+8 -3
View File
@@ -90,6 +90,9 @@ MelFilterBank : TypeAlias = NDArray[Any]
Voice : TypeAlias = NDArray[Any]
VoiceChunk : TypeAlias = NDArray[Any]
AudioCodec : TypeAlias = Literal['opus']
VideoCodec : TypeAlias = Literal['av1', 'vp8']
AomEncoder : TypeAlias = ctypes.Array[ctypes.c_char]
OpusEncoder : TypeAlias = ctypes.c_void_p
VpxEncoder : TypeAlias = ctypes.Array[ctypes.c_char]
@@ -267,13 +270,15 @@ BenchmarkCycleSet = TypedDict('BenchmarkCycleSet',
WebcamMode = Literal['inline', 'udp', 'v4l2']
StreamMode = Literal['udp', 'v4l2']
RtcVideoTrack : TypeAlias = int
RtcAudioTrack : TypeAlias = int
PeerConnection : TypeAlias = int
SdpOffer : TypeAlias = str
SdpAnswer : TypeAlias = str
MediaDirection : TypeAlias = Literal['sendonly', 'recvonly', 'sendrecv', 'inactive']
RtcVideoTrack : TypeAlias = int
RtcAudioTrack : TypeAlias = int
RtcPeer = TypedDict('RtcPeer',
{
'peer_connection': PeerConnection,
@@ -281,7 +286,7 @@ RtcPeer = TypedDict('RtcPeer',
'audio_track': RtcAudioTrack,
})
RtcStreamStore : TypeAlias = Dict[str, List[RtcPeer]]
RtcStreamStore : TypeAlias = Dict[SessionId, List[RtcPeer]]
ModelOptions : TypeAlias = Dict[str, Any]
ModelSet : TypeAlias = Dict[str, ModelOptions]
+1 -1
View File
@@ -34,7 +34,7 @@ def test_encode_aom_buffer() -> None:
aom_encoder = create_aom_encoder(video_resolution, 1000, 1, 0)
if is_linux() or is_windows():
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '4b621fb8'
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '3ab6cc31'
if is_macos():
assert create_hash(encode_aom_buffer(aom_encoder, video_buffer, video_resolution, 3)) == '64c12977'
+6 -6
View File
@@ -33,7 +33,7 @@ def test_create_peer_connection() -> None:
def test_add_audio_track() -> None:
peer_connection = rtc.create_peer_connection()
assert rtc.add_audio_track(peer_connection, 'sendonly', 111) > 0
assert rtc.add_audio_track(peer_connection, 'sendonly', 'opus', 111) > 0
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
@@ -41,7 +41,7 @@ def test_add_audio_track() -> None:
def test_add_video_track() -> None:
peer_connection = rtc.create_peer_connection()
assert rtc.add_video_track(peer_connection, 'sendonly', 96) > 0
assert rtc.add_video_track(peer_connection, 'sendonly', 'vp8', 96) > 0
datachannel_module.create_static_library().rtcDeletePeerConnection(peer_connection)
@@ -50,13 +50,13 @@ def test_negotiate_sdp() -> None:
datachannel_library = datachannel_module.create_static_library()
sender_connection = rtc.create_peer_connection()
rtc.add_video_track(sender_connection, 'sendonly', 96)
rtc.add_audio_track(sender_connection, 'sendonly', 111)
rtc.add_video_track(sender_connection, 'sendonly', 'vp8', 96)
rtc.add_audio_track(sender_connection, 'sendonly', 'opus', 111)
sdp_offer = rtc.create_sdp(sender_connection)
receiver_connection = rtc.create_peer_connection()
rtc.add_video_track(receiver_connection, 'recvonly', 96)
rtc.add_audio_track(receiver_connection, 'recvonly', 111)
rtc.add_video_track(receiver_connection, 'recvonly', 'vp8', 96)
rtc.add_audio_track(receiver_connection, 'recvonly', 'opus', 111)
sdp_answer = rtc.negotiate_sdp(receiver_connection, sdp_offer)
assert sdp_answer