mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-23 01:46:09 +02:00
refactor the release candidate
This commit is contained in:
@@ -13,7 +13,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.endpoints.stream import post_whep, websocket_stream, websocket_stream_rtc
|
||||
from facefusion.apis.endpoints.stream import post_stream, websocket_stream
|
||||
from facefusion.apis.middlewares.session import create_session_guard
|
||||
|
||||
|
||||
@@ -50,11 +50,10 @@ def create_api() -> Starlette:
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/capabilities', get_capabilities, methods = [ 'GET' ]),
|
||||
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/stream', post_stream, methods = [ 'POST' ]),
|
||||
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
|
||||
Route('/stream/{session_id}/whep', post_whep, methods = [ 'POST' ]),
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ])
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes, lifespan = lifespan)
|
||||
|
||||
@@ -15,43 +15,15 @@ from starlette.websockets import WebSocket
|
||||
from facefusion import logger, session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, create_vp8_pipe_encoder, feed_whip_frame, process_stream_frame
|
||||
from facefusion.apis.stream_helper import STREAM_FPS, compute_bitrate, compute_bufsize, get_stream_mode, process_stream_frame
|
||||
from facefusion.ffmpeg import open_vp8_encoder, write_raw_bytes
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import VisionFrame
|
||||
from facefusion.vision import convert_to_raw_rgb
|
||||
|
||||
|
||||
JPEG_MAGIC : bytes = b'\xff\xd8'
|
||||
|
||||
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
session_context.set_session_id(session_id)
|
||||
source_paths = state_manager.get_item('source_paths')
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if source_paths:
|
||||
try:
|
||||
image_buffer = await websocket.receive_bytes()
|
||||
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(target_vision_frame):
|
||||
temp_vision_frame = process_vision_frame(target_vision_frame)
|
||||
is_success, output_vision_frame = cv2.imencode('.jpg', temp_vision_frame)
|
||||
|
||||
if is_success:
|
||||
await websocket.send_bytes(output_vision_frame.tobytes())
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
await websocket.close()
|
||||
|
||||
|
||||
def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lock) -> None:
|
||||
pipe_handle = process.stdout.fileno()
|
||||
|
||||
@@ -135,11 +107,13 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s
|
||||
|
||||
if not encoder:
|
||||
height, width = temp_vision_frame.shape[:2]
|
||||
encoder = create_vp8_pipe_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
|
||||
stream_bitrate = compute_bitrate(width, height)
|
||||
stream_bufsize = compute_bufsize(width, height)
|
||||
encoder = open_vp8_encoder(width, height, STREAM_FPS, stream_bitrate, stream_bufsize)
|
||||
threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True).start()
|
||||
logger.info('vp8 direct encoder started ' + str(width) + 'x' + str(height), __name__)
|
||||
|
||||
feed_whip_frame(encoder, temp_vision_frame)
|
||||
write_raw_bytes(encoder, convert_to_raw_rgb(temp_vision_frame))
|
||||
|
||||
with vp8_lock:
|
||||
if vp8_frames:
|
||||
@@ -162,8 +136,9 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s
|
||||
encoder.wait(timeout = 5)
|
||||
|
||||
|
||||
async def websocket_stream_rtc(websocket : WebSocket) -> None:
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
stream_mode = get_stream_mode(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
@@ -173,57 +148,80 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None:
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if source_paths:
|
||||
from facefusion import rtc
|
||||
if stream_mode == 'video':
|
||||
await handle_video_stream(websocket, session_id)
|
||||
return
|
||||
|
||||
stream_path = 'stream/' + session_id
|
||||
rtc.create_session(stream_path)
|
||||
whep_url = '/' + stream_path + '/whep'
|
||||
|
||||
latest_frame_holder : list = [None]
|
||||
whep_sent = False
|
||||
lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
ready_event = threading.Event()
|
||||
worker = threading.Thread(target = run_rtc_direct_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, stream_path), daemon = True)
|
||||
worker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
|
||||
if not whep_sent and ready_event.is_set():
|
||||
await websocket.send_text(whep_url)
|
||||
whep_sent = True
|
||||
|
||||
if message.get('bytes'):
|
||||
data = message.get('bytes')
|
||||
|
||||
if data[:2] == JPEG_MAGIC:
|
||||
frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(frame):
|
||||
with lock:
|
||||
latest_frame_holder[0] = frame
|
||||
|
||||
if data[:2] != JPEG_MAGIC:
|
||||
rtc.send_audio(stream_path, data)
|
||||
|
||||
except Exception as exception:
|
||||
logger.error(str(exception), __name__)
|
||||
|
||||
stop_event.set()
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, worker.join, 10)
|
||||
rtc.destroy_session(stream_path)
|
||||
await handle_image_stream(websocket)
|
||||
return
|
||||
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def post_whep(request : Request) -> Response:
|
||||
async def handle_image_stream(websocket : WebSocket) -> None:
	"""
	Process one image received over the websocket and send the result back.

	A single binary message is read, decoded with OpenCV, passed through
	process_vision_frame and returned to the client as JPEG bytes. Nothing is
	sent when the buffer does not decode into a frame holding any non-zero
	value or when the JPEG encoding fails.

	Failures are logged instead of raised, so the caller never receives an
	exception from this handler.
	"""
	try:
		image_buffer = await websocket.receive_bytes()
		target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)

		if numpy.any(target_vision_frame):
			temp_vision_frame = process_vision_frame(target_vision_frame)
			is_success, output_vision_frame = cv2.imencode('.jpg', temp_vision_frame)

			if is_success:
				await websocket.send_bytes(output_vision_frame.tobytes())

	except Exception as exception:
		# stay best-effort, but leave a trace the same way the video handler does
		logger.error(str(exception), __name__)
|
||||
|
||||
|
||||
async def handle_video_stream(websocket : WebSocket, session_id : str) -> None:
	"""
	Feed frames and audio received over the websocket into an rtc session.

	An rtc session is created under 'stream/' + session_id and a daemon thread
	running run_rtc_direct_pipeline is started for it. Each binary message that
	starts with JPEG_MAGIC is decoded and stored as the latest frame for the
	worker, every other binary message is handed to rtc.send_audio. Once the
	worker has set its ready event, the text 'ready:' + session_id is sent a
	single time; the event is only checked when a message arrives.

	:param websocket: accepted websocket connection to read from
	:param session_id: identifier used to build the rtc stream path
	"""
	from facefusion import rtc

	stream_path = 'stream/' + session_id
	rtc.create_session(stream_path)

	latest_frame_holder : list = [None]
	ready_sent = False
	lock = threading.Lock()
	stop_event = threading.Event()
	ready_event = threading.Event()
	worker = threading.Thread(target = run_rtc_direct_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, stream_path), daemon = True)
	worker.start()

	try:
		while True:
			message = await websocket.receive()

			# leave the loop on a regular disconnect instead of failing on the next receive
			if message.get('type') == 'websocket.disconnect':
				break

			if not ready_sent and ready_event.is_set():
				await websocket.send_text('ready:' + session_id)
				ready_sent = True

			data = message.get('bytes')

			if not data:
				continue

			if data[:2] == JPEG_MAGIC:
				frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)

				if numpy.any(frame):
					with lock:
						latest_frame_holder[0] = frame
			else:
				rtc.send_audio(stream_path, data)

	except Exception as exception:
		logger.error(str(exception), __name__)

	finally:
		# runs on cancellation as well, so the worker and the rtc session are always released
		stop_event.set()
		loop = asyncio.get_running_loop()
		# join in the executor with a 10 second timeout to keep the event loop responsive
		await loop.run_in_executor(None, worker.join, 10)
		rtc.destroy_session(stream_path)
|
||||
|
||||
|
||||
async def post_stream(request : Request) -> Response:
|
||||
from facefusion import rtc
|
||||
|
||||
session_id = request.query_params.get('session_id')
|
||||
stream_path = 'stream/' + session_id
|
||||
body = await request.body()
|
||||
sdp_offer = body.decode('utf-8')
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import Scope
|
||||
|
||||
from facefusion import ffmpeg_builder
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
@@ -38,40 +38,18 @@ def compute_bufsize(width : int, height : int) -> str:
|
||||
return '10000k'
|
||||
|
||||
|
||||
def create_vp8_pipe_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
|
||||
commands = ffmpeg_builder.chain(
|
||||
[ '-use_wallclock_as_timestamps', '1' ],
|
||||
ffmpeg_builder.capture_video(),
|
||||
ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
|
||||
ffmpeg_builder.set_input('-'),
|
||||
[ '-c:v', 'libvpx' ],
|
||||
[ '-deadline', 'realtime' ],
|
||||
[ '-cpu-used', '8' ],
|
||||
[ '-pix_fmt', 'yuv420p' ],
|
||||
[ '-crf', '10' ],
|
||||
[ '-b:v', compute_bitrate(width, height) ],
|
||||
[ '-maxrate', compute_bitrate(width, height) ],
|
||||
[ '-bufsize', compute_bufsize(width, height) ],
|
||||
[ '-g', str(stream_fps) ],
|
||||
[ '-keyint_min', str(stream_fps) ],
|
||||
[ '-error-resilient', '1' ],
|
||||
[ '-lag-in-frames', '0' ],
|
||||
[ '-rc_lookahead', '0' ],
|
||||
[ '-threads', '4' ],
|
||||
[ '-an' ],
|
||||
[ '-f', 'ivf' ],
|
||||
ffmpeg_builder.set_output('-')
|
||||
)
|
||||
commands = ffmpeg_builder.run(commands)
|
||||
process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
|
||||
return process
|
||||
|
||||
|
||||
def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None:
|
||||
raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
|
||||
process.stdin.write(raw_bytes)
|
||||
process.stdin.flush()
|
||||
|
||||
|
||||
def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame:
	"""
	Run a stream frame through process_vision_frame and return its result.
	"""
	output_vision_frame = process_vision_frame(vision_frame)
	return output_vision_frame
|
||||
|
||||
|
||||
def get_stream_mode(scope : Scope) -> Optional[str]:
	"""
	Pick the stream mode out of the 'Sec-WebSocket-Protocol' header.

	The header is split on commas and the first entry that equals 'image' or
	'video' after stripping whitespace is returned. None is returned when the
	header is missing, empty or holds neither value.
	"""
	header_value = Headers(scope = scope).get('Sec-WebSocket-Protocol')

	if not header_value:
		return None

	candidates = (token.strip() for token in header_value.split(','))
	return next((candidate for candidate in candidates if candidate in [ 'image', 'video' ]), None)
|
||||
|
||||
@@ -70,6 +70,37 @@ def open_ffmpeg(commands : List[Command]) -> subprocess.Popen[bytes]:
|
||||
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
||||
|
||||
|
||||
def open_vp8_encoder(width : int, height : int, stream_fps : int, stream_bitrate : str, stream_bufsize : str) -> subprocess.Popen[bytes]:
	"""
	Start a libvpx encoder process that reads frames from stdin and writes ivf to stdout.

	The argument list is assembled through ffmpeg_builder and the process is
	opened with stdin, stdout and stderr attached as pipes.

	:param width: frame width used for the media resolution
	:param height: frame height used for the media resolution
	:param stream_fps: value passed to the keyframe interval
	:param stream_bitrate: value passed to the video bitrate
	:param stream_bufsize: value passed to the video bufsize
	"""
	media_resolution = 'x'.join([ str(width), str(height) ])
	input_commands =\
	[
		ffmpeg_builder.use_wallclock_timestamps(),
		ffmpeg_builder.capture_video(),
		ffmpeg_builder.set_media_resolution(media_resolution),
		ffmpeg_builder.set_input('-')
	]
	encoder_commands =\
	[
		ffmpeg_builder.set_video_encoder('libvpx'),
		[ '-deadline', 'realtime' ],
		[ '-cpu-used', '8' ],
		ffmpeg_builder.enforce_pixel_format('yuv420p'),
		[ '-crf', '10' ],
		ffmpeg_builder.set_video_bitrate(stream_bitrate),
		ffmpeg_builder.set_video_bufsize(stream_bufsize),
		ffmpeg_builder.set_keyframe_interval(stream_fps),
		[ '-error-resilient', '1' ],
		[ '-lag-in-frames', '0' ],
		[ '-rc_lookahead', '0' ],
		[ '-threads', '4' ]
	]
	output_commands =\
	[
		ffmpeg_builder.ignore_audio_stream(),
		ffmpeg_builder.set_output_format('ivf'),
		ffmpeg_builder.set_output('-')
	]
	chained_commands = ffmpeg_builder.chain(*input_commands, *encoder_commands, *output_commands)
	process_commands = ffmpeg_builder.run(chained_commands)
	return subprocess.Popen(process_commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
|
||||
|
||||
|
||||
def write_raw_bytes(process : subprocess.Popen[bytes], raw_bytes : bytes) -> None:
	"""
	Write the raw bytes to the stdin of the process and flush right away.
	"""
	stdin_pipe = process.stdin
	stdin_pipe.write(raw_bytes)
	stdin_pipe.flush()
|
||||
|
||||
|
||||
def log_debug(process : subprocess.Popen[bytes]) -> None:
|
||||
_, stderr = process.communicate()
|
||||
errors = stderr.decode().split(os.linesep)
|
||||
|
||||
@@ -259,6 +259,30 @@ def capture_video() -> List[Command]:
|
||||
return [ '-f', 'rawvideo', '-pix_fmt', 'rgb24' ]
|
||||
|
||||
|
||||
def use_wallclock_timestamps() -> List[Command]:
	"""
	Return the arguments that set '-use_wallclock_as_timestamps' to '1'.
	"""
	option, value = '-use_wallclock_as_timestamps', '1'
	return [ option, value ]
|
||||
|
||||
|
||||
def set_video_bitrate(video_bitrate : str) -> List[Command]:
	"""
	Return the arguments that apply the same bitrate to '-b:v' and '-maxrate'.
	"""
	commands = []

	for option in [ '-b:v', '-maxrate' ]:
		commands.extend([ option, video_bitrate ])
	return commands
|
||||
|
||||
|
||||
def set_video_bufsize(video_bufsize : str) -> List[Command]:
	"""
	Return the '-bufsize' argument followed by the given value.
	"""
	commands = [ '-bufsize' ]
	commands.append(video_bufsize)
	return commands
|
||||
|
||||
|
||||
def set_keyframe_interval(interval : int) -> List[Command]:
	"""
	Return the arguments that set both '-g' and '-keyint_min' to the interval.
	"""
	interval_value = str(interval)
	return [ '-g', interval_value, '-keyint_min', interval_value ]
|
||||
|
||||
|
||||
def ignore_audio_stream() -> List[Command]:
	"""
	Return the '-an' argument.
	"""
	commands = [ '-an' ]
	return commands
|
||||
|
||||
|
||||
def set_output_format(output_format : str) -> List[Command]:
	"""
	Return the '-f' argument followed by the given output format.
	"""
	commands = [ '-f' ]
	commands.append(output_format)
	return commands
|
||||
|
||||
|
||||
def ignore_video_stream() -> List[Command]:
	"""
	Return the '-vn' argument.
	"""
	commands = [ '-vn' ]
	return commands
|
||||
|
||||
|
||||
+58
-9
@@ -1,11 +1,16 @@
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List, Optional, TypeAlias
|
||||
|
||||
from facefusion import logger
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.filesystem import resolve_relative_path
|
||||
from facefusion.types import DownloadSet
|
||||
|
||||
RtcLib : TypeAlias = ctypes.CDLL
|
||||
|
||||
@@ -60,18 +65,62 @@ class RtcPacketizerInit(ctypes.Structure):
|
||||
]
|
||||
|
||||
|
||||
def get_binary_name() -> str:
	"""
	Return the binary file name that belongs to the current platform.

	The name is looked up by the value of platform.system(). An empty string
	is returned for any system other than Linux, Darwin and Windows.
	"""
	binary_names =\
	{
		'Linux': 'linux-x64-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.so',
		'Darwin': 'macos-universal-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.dylib',
		'Windows': 'windows-x64-openssl-h264-vp8-av1-opus-datachannel-0.24.1.dll'
	}
	return binary_names.get(platform.system(), '')
|
||||
|
||||
|
||||
@lru_cache
def create_static_download_set() -> Dict[str, DownloadSet]:
	"""
	Build the cached download set for the platform binary and its hash file.

	Both entries are stored under the 'datachannel' key, the hash below
	'hashes' and the binary below 'sources'. Each entry pairs a url resolved
	for 'binaries-1.0.0' with a path below '../.assets/binaries/'.
	"""
	binary_name = get_binary_name()
	hash_name = binary_name + '.hash'
	binaries_path = '../.assets/binaries/'
	hash_download =\
	{
		'url': resolve_download_url('binaries-1.0.0', hash_name),
		'path': resolve_relative_path(binaries_path + hash_name)
	}
	source_download =\
	{
		'url': resolve_download_url('binaries-1.0.0', binary_name),
		'path': resolve_relative_path(binaries_path + binary_name)
	}
	download_set =\
	{
		'hashes':
		{
			'datachannel': hash_download
		},
		'sources':
		{
			'datachannel': source_download
		}
	}
	return download_set
|
||||
|
||||
|
||||
def pre_check() -> bool:
	"""
	Download the hashes and then the sources of the static download set.

	The sources are only requested once conditional_download_hashes reports
	success, otherwise False is returned straight away.
	"""
	download_set = create_static_download_set()
	has_hashes = conditional_download_hashes(download_set.get('hashes'))

	if has_hashes:
		return conditional_download_sources(download_set.get('sources'))
	return False
|
||||
|
||||
|
||||
def find_library() -> Optional[str]:
	"""
	Locate the datachannel library for the current platform.

	The directories '../.assets/binaries' and '../bin', both resolved through
	resolve_relative_path, are scanned in that order. The first file whose name
	contains 'datachannel' and ends with the platform extension ('.dll' on
	Windows, '.dylib' on Darwin, '.so' otherwise) is returned as a full path.
	Missing directories are skipped and None is returned when nothing matches.
	"""
	system = platform.system()
	ext = '.dll' if system == 'Windows' else '.dylib' if system == 'Darwin' else '.so'

	for search_dir in [ resolve_relative_path('../.assets/binaries'), resolve_relative_path('../bin') ]:
		# a missing directory must not abort the lookup, the next one may hold the library
		if not os.path.isdir(search_dir):
			continue

		for name in os.listdir(search_dir):
			if 'datachannel' in name and name.endswith(ext):
				return os.path.join(search_dir, name)

	return None
|
||||
|
||||
|
||||
@@ -278,6 +278,10 @@ def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame
|
||||
return target_vision_frame
|
||||
|
||||
|
||||
def convert_to_raw_rgb(vision_frame : VisionFrame) -> bytes:
	"""
	Convert the frame with cv2.COLOR_BGR2RGB and return its bytes.
	"""
	rgb_vision_frame = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB)
	return rgb_vision_frame.tobytes()
|
||||
|
||||
|
||||
def calculate_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float:
|
||||
histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
|
||||
histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
|
||||
|
||||
+6
-5
@@ -847,8 +847,8 @@ async function connect() {
|
||||
outputVideo.removeAttribute('src');
|
||||
outputVideo.load();
|
||||
|
||||
var wsUrl = wsBase() + '/stream/rtc';
|
||||
var protocols = ['access_token.' + accessToken];
|
||||
var wsUrl = wsBase() + '/stream';
|
||||
var protocols = ['access_token.' + accessToken, 'video'];
|
||||
var t0 = performance.now();
|
||||
log('ws → ' + wsUrl, 'info');
|
||||
|
||||
@@ -880,8 +880,9 @@ async function connect() {
|
||||
var streamStarted = false;
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
if (typeof event.data === 'string' && !whepUrlFromServer) {
|
||||
whepUrlFromServer = base() + event.data;
|
||||
if (typeof event.data === 'string' && event.data.startsWith('ready:') && !whepUrlFromServer) {
|
||||
var sessionId = event.data.split(':')[1];
|
||||
whepUrlFromServer = base() + '/stream?session_id=' + sessionId;
|
||||
|
||||
if (!streamStarted) {
|
||||
streamStarted = true;
|
||||
@@ -890,7 +891,7 @@ async function connect() {
|
||||
log('stream output started', 'ok');
|
||||
}
|
||||
|
||||
log('stream ready (' + Math.round(performance.now() - t0) + 'ms) — WHEP url: ' + whepUrlFromServer, 'ok');
|
||||
log('stream ready (' + Math.round(performance.now() - t0) + 'ms)', 'ok');
|
||||
|
||||
var tWhep = performance.now();
|
||||
connectWhep(whepUrlFromServer).then(function() {
|
||||
|
||||
Reference in New Issue
Block a user