mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-25 19:06:21 +02:00
Dynamic bitrate for webrtc stream (#1063)
* use custom aiortc * update naming to bitrate_
This commit is contained in:
@@ -11,7 +11,7 @@ from starlette.websockets import WebSocket
|
||||
from facefusion import session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import on_video_track
|
||||
from facefusion.apis.stream_helper import create_output_track, on_video_track
|
||||
from facefusion.streamer import process_vision_frame
|
||||
|
||||
|
||||
@@ -51,10 +51,18 @@ async def webrtc_stream(request : Request) -> Response:
|
||||
|
||||
if session_id:
|
||||
body = await request.json()
|
||||
buffer_size = int(body.get('buffer_size', 30))
|
||||
bitrate_init = int(body.get('bitrate_init', 100000))
|
||||
bitrate_min = int(body.get('bitrate_min', 100000))
|
||||
bitrate_max = int(body.get('bitrate_max', 4000000))
|
||||
|
||||
rtc_offer = RTCSessionDescription(sdp = body.get('sdp'), type = body.get('type'))
|
||||
rtc_connection = RTCPeerConnection()
|
||||
|
||||
rtc_connection.on('track', partial(on_video_track, rtc_connection))
|
||||
output_track, sender = create_output_track(rtc_connection, buffer_size)
|
||||
sender.configure_bitrate(bitrate_init, bitrate_min, bitrate_max)
|
||||
|
||||
rtc_connection.on('track', partial(on_video_track, rtc_connection, output_track))
|
||||
|
||||
await rtc_connection.setRemoteDescription(rtc_offer)
|
||||
await rtc_connection.setLocalDescription(await rtc_connection.createAnswer())
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from typing import Tuple
|
||||
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, VideoStreamTrack
|
||||
from aiortc import MediaStreamTrack, QueuedVideoStreamTrack, RTCPeerConnection, RTCRtpSender
|
||||
from aiortc.mediastreams import MediaStreamError
|
||||
from av import VideoFrame
|
||||
|
||||
from facefusion.streamer import process_vision_frame
|
||||
@@ -10,27 +11,34 @@ from facefusion.streamer import process_vision_frame
|
||||
def process_stream_frame(target_stream_frame : VideoFrame) -> VideoFrame:
|
||||
target_vision_frame = target_stream_frame.to_ndarray(format = 'bgr24')
|
||||
output_vision_frame = process_vision_frame(target_vision_frame)
|
||||
return VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
|
||||
output_stream_frame = VideoFrame.from_ndarray(output_vision_frame, format = 'bgr24')
|
||||
output_stream_frame.pts = target_stream_frame.pts
|
||||
output_stream_frame.time_base = target_stream_frame.time_base
|
||||
return output_stream_frame
|
||||
|
||||
|
||||
def create_output_track(target_track : MediaStreamTrack) -> VideoStreamTrack:
|
||||
output_track = VideoStreamTrack()
|
||||
|
||||
async def read_stream_frame() -> VideoFrame:
|
||||
target_stream_frame = cast(VideoFrame, await target_track.recv())
|
||||
output_stream_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_stream_frame)
|
||||
output_stream_frame.pts = target_stream_frame.pts
|
||||
output_stream_frame.time_base = target_stream_frame.time_base
|
||||
return output_stream_frame
|
||||
|
||||
output_track.recv = read_stream_frame
|
||||
return output_track
|
||||
def create_output_track(rtc_connection : RTCPeerConnection, buffer_size : int) -> Tuple[QueuedVideoStreamTrack, RTCRtpSender]:
|
||||
output_track = QueuedVideoStreamTrack(buffer_size = buffer_size)
|
||||
sender = rtc_connection.addTrack(output_track)
|
||||
return output_track, sender
|
||||
|
||||
|
||||
def on_video_track(rtc_connection : RTCPeerConnection, target_track : MediaStreamTrack) -> None:
|
||||
async def process_and_enqueue(target_track : MediaStreamTrack, output_track : QueuedVideoStreamTrack) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
while True:
|
||||
try:
|
||||
target_stream_frame = await target_track.recv()
|
||||
except MediaStreamError:
|
||||
pass
|
||||
|
||||
output_stream_frame = await loop.run_in_executor(None, process_stream_frame, target_stream_frame) #type:ignore[arg-type]
|
||||
await output_track.put(output_stream_frame)
|
||||
|
||||
|
||||
def on_video_track(rtc_connection : RTCPeerConnection, output_track : QueuedVideoStreamTrack, target_track : MediaStreamTrack) -> None:
|
||||
if target_track.kind == 'audio':
|
||||
rtc_connection.addTrack(target_track)
|
||||
|
||||
if target_track.kind == 'video':
|
||||
output_track = create_output_track(target_track)
|
||||
rtc_connection.addTrack(output_track)
|
||||
asyncio.create_task(process_and_enqueue(target_track, output_track))
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ nvidia-ml-py==13.590.48
|
||||
psutil==7.2.2
|
||||
tqdm==4.67.3
|
||||
scipy==1.16.3
|
||||
aiortc==1.14.0
|
||||
aiortc @ git+https://github.com/facefusion/aiortc.git@feat/dynamic-bitrate
|
||||
starlette==0.52.1
|
||||
uvicorn==0.41.0
|
||||
websockets==16.0
|
||||
|
||||
Reference in New Issue
Block a user