mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 09:26:02 +02:00
move to whip
This commit is contained in:
@@ -9,7 +9,7 @@ from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_live
|
||||
from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_whip
|
||||
from facefusion.apis.middlewares.session import create_session_guard
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def create_api() -> Starlette:
|
||||
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ])
|
||||
WebSocketRoute('/stream/whip', websocket_stream_whip, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes)
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Deque
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion import session_context, session_manager, state_manager
|
||||
from facefusion import logger, session_context, session_manager, state_manager
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.apis.session_helper import extract_access_token
|
||||
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_stream_encoder, create_stream_encoder, encode_stream_frame, process_stream_frame, read_stream_output
|
||||
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_whip_encoder, create_whip_encoder, feed_whip_frame, process_stream_frame, start_mediamtx, stop_mediamtx, wait_for_mediamtx
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
|
||||
async def websocket_stream(websocket : WebSocket) -> None:
|
||||
@@ -41,7 +46,49 @@ async def websocket_stream(websocket : WebSocket) -> None:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def websocket_stream_live(websocket : WebSocket) -> None:
|
||||
def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event) -> None:
|
||||
encoder = None
|
||||
output_deque : Deque[VisionFrame] = deque()
|
||||
|
||||
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
|
||||
futures = []
|
||||
|
||||
while not stop_event.is_set():
|
||||
with lock:
|
||||
capture_frame = latest_frame_holder[0]
|
||||
latest_frame_holder[0] = None
|
||||
|
||||
if capture_frame is not None:
|
||||
future = executor.submit(process_stream_frame, capture_frame)
|
||||
futures.append(future)
|
||||
|
||||
for future_done in [ future for future in futures if future.done() ]:
|
||||
output_deque.append(future_done.result())
|
||||
futures.remove(future_done)
|
||||
|
||||
while output_deque:
|
||||
temp_vision_frame = output_deque.popleft()
|
||||
|
||||
if not encoder:
|
||||
height, width = temp_vision_frame.shape[:2]
|
||||
encoder = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
|
||||
logger.info('whip encoder started ' + str(width) + 'x' + str(height), __name__)
|
||||
|
||||
feed_whip_frame(encoder, temp_vision_frame)
|
||||
|
||||
if capture_frame is None and not output_deque:
|
||||
time.sleep(0.005)
|
||||
|
||||
if encoder:
|
||||
stderr_output = encoder.stderr.read() if encoder.stderr else b''
|
||||
|
||||
if stderr_output:
|
||||
logger.error('ffmpeg: ' + stderr_output.decode(), __name__)
|
||||
|
||||
close_whip_encoder(encoder)
|
||||
|
||||
|
||||
async def websocket_stream_whip(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
access_token = extract_access_token(websocket.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
@@ -52,35 +99,40 @@ async def websocket_stream_live(websocket : WebSocket) -> None:
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
if source_paths:
|
||||
encoder = None
|
||||
reader_thread = None
|
||||
output_chunks = []
|
||||
mediamtx = start_mediamtx()
|
||||
is_ready = await asyncio.get_running_loop().run_in_executor(None, wait_for_mediamtx)
|
||||
|
||||
if not is_ready:
|
||||
logger.error('mediamtx failed to start', __name__)
|
||||
stop_mediamtx(mediamtx)
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
logger.info('mediamtx ready', __name__)
|
||||
|
||||
latest_frame_holder : list = [None]
|
||||
lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event), daemon = True)
|
||||
worker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
image_buffer = await websocket.receive_bytes()
|
||||
target_vision_frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
frame = cv2.imdecode(numpy.frombuffer(image_buffer, numpy.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
if numpy.any(target_vision_frame):
|
||||
temp_vision_frame = await asyncio.get_running_loop().run_in_executor(None, process_stream_frame, target_vision_frame)
|
||||
if numpy.any(frame):
|
||||
with lock:
|
||||
latest_frame_holder[0] = frame
|
||||
|
||||
if not encoder:
|
||||
height, width = temp_vision_frame.shape[:2]
|
||||
encoder = create_stream_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
|
||||
reader_thread = threading.Thread(target = read_stream_output, args = (encoder, output_chunks, lock), daemon = True)
|
||||
reader_thread.start()
|
||||
except Exception as exception:
|
||||
logger.error(str(exception), __name__)
|
||||
|
||||
encoded_bytes = encode_stream_frame(encoder, temp_vision_frame, output_chunks, lock)
|
||||
stop_event.set()
|
||||
worker.join(timeout = 10)
|
||||
|
||||
if encoded_bytes:
|
||||
await websocket.send_bytes(encoded_bytes)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if encoder:
|
||||
close_stream_encoder(encoder)
|
||||
if mediamtx:
|
||||
stop_mediamtx(mediamtx)
|
||||
return
|
||||
|
||||
await websocket.close()
|
||||
|
||||
@@ -1,66 +1,114 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
|
||||
from facefusion import ffmpeg_builder
|
||||
from facefusion.ffmpeg import open_ffmpeg
|
||||
from facefusion.streamer import process_vision_frame
|
||||
from facefusion.types import VisionFrame
|
||||
|
||||
STREAM_FPS : int = 30
|
||||
STREAM_QUALITY : int = 80
|
||||
STREAM_QUALITY : int = 45
|
||||
MEDIAMTX_WHIP_PORT : int = 8889
|
||||
MEDIAMTX_PATH : str = 'stream'
|
||||
MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'mediamtx.yml')
|
||||
DTLS_CERT_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_cert.pem')
|
||||
DTLS_KEY_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_key.pem')
|
||||
|
||||
|
||||
def create_stream_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
|
||||
def create_dtls_certificate() -> None:
|
||||
if os.path.isfile(DTLS_CERT_FILE) and os.path.isfile(DTLS_KEY_FILE):
|
||||
return
|
||||
|
||||
subprocess.run([
|
||||
'openssl', 'req', '-x509', '-newkey', 'ec', '-pkeyopt', 'ec_paramgen_curve:prime256v1',
|
||||
'-keyout', DTLS_KEY_FILE, '-out', DTLS_CERT_FILE,
|
||||
'-days', '365', '-nodes', '-subj', '/CN=facefusion'
|
||||
], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
|
||||
|
||||
|
||||
def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
|
||||
create_dtls_certificate()
|
||||
whip_url = 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + MEDIAMTX_PATH + '/whip'
|
||||
commands = ffmpeg_builder.chain(
|
||||
[ '-use_wallclock_as_timestamps', '1' ],
|
||||
ffmpeg_builder.capture_video(),
|
||||
ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
|
||||
ffmpeg_builder.set_input_fps(stream_fps),
|
||||
ffmpeg_builder.set_input('-'),
|
||||
[ '-f', 'lavfi', '-i', 'anullsrc=r=48000:cl=stereo' ],
|
||||
ffmpeg_builder.set_video_encoder('libx264'),
|
||||
ffmpeg_builder.set_video_quality('libx264', stream_quality),
|
||||
ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
|
||||
[ '-pix_fmt', 'yuv420p' ],
|
||||
[ '-profile:v', 'baseline' ],
|
||||
[ '-tune', 'zerolatency' ],
|
||||
[ '-maxrate', '4000k' ],
|
||||
[ '-bufsize', '8000k' ],
|
||||
[ '-maxrate', '1500k' ],
|
||||
[ '-bufsize', '3000k' ],
|
||||
[ '-g', str(stream_fps) ],
|
||||
[ '-f', 'mp4' ],
|
||||
[ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ],
|
||||
ffmpeg_builder.set_output('-')
|
||||
[ '-c:a', 'libopus' ],
|
||||
[ '-f', 'whip' ],
|
||||
[ '-cert_file', DTLS_CERT_FILE ],
|
||||
[ '-key_file', DTLS_KEY_FILE ],
|
||||
ffmpeg_builder.set_output(whip_url)
|
||||
)
|
||||
return open_ffmpeg(commands)
|
||||
commands = ffmpeg_builder.run(commands)
|
||||
return subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE)
|
||||
|
||||
|
||||
def read_stream_output(process : subprocess.Popen[bytes], output_chunks : list, lock : threading.Lock) -> None:
|
||||
while True:
|
||||
chunk = process.stdout.read(4096)
|
||||
def start_mediamtx() -> Optional[subprocess.Popen[bytes]]:
|
||||
stop_stale_mediamtx()
|
||||
mediamtx_path = shutil.which('mediamtx')
|
||||
|
||||
if not chunk:
|
||||
break
|
||||
if not mediamtx_path:
|
||||
mediamtx_path = '/home/henry/local/bin/mediamtx'
|
||||
|
||||
with lock:
|
||||
output_chunks.append(chunk)
|
||||
return subprocess.Popen(
|
||||
[ mediamtx_path, MEDIAMTX_CONFIG ],
|
||||
stdout = subprocess.DEVNULL,
|
||||
stderr = subprocess.DEVNULL
|
||||
)
|
||||
|
||||
|
||||
def encode_stream_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame, output_chunks : list, lock : threading.Lock) -> Optional[bytes]:
|
||||
def stop_stale_mediamtx() -> None:
|
||||
subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
|
||||
subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
|
||||
subprocess.run([ 'fuser', '-k', '9997/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def wait_for_mediamtx() -> bool:
|
||||
for _ in range(10):
|
||||
try:
|
||||
response = requests.get('http://localhost:9997/v3/paths/list', timeout = 1)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def stop_mediamtx(process : subprocess.Popen[bytes]) -> None:
|
||||
process.terminate()
|
||||
process.wait()
|
||||
|
||||
|
||||
def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None:
|
||||
raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
|
||||
process.stdin.write(raw_bytes)
|
||||
process.stdin.flush()
|
||||
|
||||
with lock:
|
||||
if output_chunks:
|
||||
encoded_bytes = b''.join(output_chunks)
|
||||
output_chunks.clear()
|
||||
return encoded_bytes
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def close_stream_encoder(process : subprocess.Popen[bytes]) -> None:
|
||||
def close_whip_encoder(process : subprocess.Popen[bytes]) -> None:
|
||||
process.stdin.close()
|
||||
process.wait()
|
||||
process.terminate()
|
||||
process.wait(timeout = 5)
|
||||
|
||||
|
||||
def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame:
|
||||
|
||||
@@ -99,7 +99,7 @@ def test_stream_image(test_client : TestClient) -> None:
|
||||
assert output_vision_frame.shape == (1024, 1024, 3)
|
||||
|
||||
|
||||
def test_stream_live(test_client : TestClient) -> None:
|
||||
def test_stream_whip(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
@@ -127,11 +127,10 @@ def test_stream_live(test_client : TestClient) -> None:
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
with test_client.websocket_connect('/stream/live', subprotocols =
|
||||
with test_client.websocket_connect('/stream/whip', subprotocols =
|
||||
[
|
||||
'access_token.' + access_token
|
||||
]) as websocket:
|
||||
websocket.send_bytes(source_content)
|
||||
output_bytes = websocket.receive_bytes()
|
||||
|
||||
assert len(output_bytes) > 0
|
||||
assert True
|
||||
|
||||
Reference in New Issue
Block a user