diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 7c986ae3..015966f9 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -12,7 +12,8 @@ from starlette.websockets import WebSocket from facefusion import logger, session_context, session_manager, state_manager from facefusion.apis.api_helper import get_sec_websocket_protocol from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_whip_encoder, create_whip_encoder, feed_whip_audio, feed_whip_frame, process_stream_frame, start_mediamtx, stop_mediamtx, wait_for_mediamtx +from facefusion import mediamtx +from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_whip_encoder, create_whip_encoder, feed_whip_audio, feed_whip_frame, process_stream_frame from facefusion.streamer import process_vision_frame from facefusion.types import VisionFrame @@ -103,12 +104,12 @@ async def websocket_stream_whip(websocket : WebSocket) -> None: await websocket.accept(subprotocol = subprotocol) if source_paths: - mediamtx = start_mediamtx() - is_ready = await asyncio.get_running_loop().run_in_executor(None, wait_for_mediamtx) + mediamtx_process = mediamtx.start() + is_ready = await asyncio.get_running_loop().run_in_executor(None, mediamtx.wait_for_ready) if not is_ready: logger.error('mediamtx failed to start', __name__) - stop_mediamtx(mediamtx) + mediamtx.stop(mediamtx_process) await websocket.close() return @@ -144,8 +145,8 @@ async def websocket_stream_whip(websocket : WebSocket) -> None: stop_event.set() worker.join(timeout = 10) - if mediamtx: - stop_mediamtx(mediamtx) + if mediamtx_process: + mediamtx.stop(mediamtx_process) return await websocket.close() diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index ddd459e1..3fd69023 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -1,23 +1,17 @@ import os -import shutil import subprocess import tempfile -import time -from typing import Optional, Tuple +from typing import Tuple import cv2 -import requests -from facefusion import ffmpeg_builder +from facefusion import ffmpeg_builder, mediamtx from facefusion.streamer import process_vision_frame from facefusion.types import VisionFrame STREAM_FPS : int = 30 STREAM_QUALITY : int = 45 STREAM_AUDIO_RATE : int = 48000 -MEDIAMTX_WHIP_PORT : int = 8889 -MEDIAMTX_PATH : str = 'stream' -MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'mediamtx.yml') DTLS_CERT_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_cert.pem') DTLS_KEY_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_key.pem') @@ -36,7 +30,7 @@ def create_dtls_certificate() -> None: def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> Tuple[subprocess.Popen[bytes], int]: create_dtls_certificate() audio_read_fd, audio_write_fd = os.pipe() - whip_url = 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + MEDIAMTX_PATH + '/whip' + whip_url = mediamtx.get_whip_url() commands = ffmpeg_builder.chain( [ '-use_wallclock_as_timestamps', '1' ], ffmpeg_builder.capture_video(), @@ -65,45 +59,6 @@ def create_whip_encoder(width : int, height : int, stream_fps : int, stream_qual return process, audio_write_fd -def start_mediamtx() -> Optional[subprocess.Popen[bytes]]: - stop_stale_mediamtx() - mediamtx_path = shutil.which('mediamtx') - - if not mediamtx_path: - mediamtx_path = '/home/henry/local/bin/mediamtx' - - return subprocess.Popen( - [ mediamtx_path, MEDIAMTX_CONFIG ], - stdout = subprocess.DEVNULL, - stderr = subprocess.DEVNULL - ) - - -def stop_stale_mediamtx() -> None: - subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - subprocess.run([ 'fuser', '-k', '9997/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) - time.sleep(1) - - -def wait_for_mediamtx() -> bool: - for _ in range(10): - try: - response = requests.get('http://localhost:9997/v3/paths/list', timeout = 1) - - if response.status_code == 200: - return True - except Exception: - pass - time.sleep(0.5) - return False - - -def stop_mediamtx(process : subprocess.Popen[bytes]) -> None: - process.terminate() - process.wait() - - def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None: raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes() process.stdin.write(raw_bytes) diff --git a/facefusion/mediamtx.py b/facefusion/mediamtx.py new file mode 100644 index 00000000..0389a9cc --- /dev/null +++ b/facefusion/mediamtx.py @@ -0,0 +1,68 @@ +import os +import shutil +import subprocess +import time +from typing import Optional + +import httpx + + +MEDIAMTX_WHIP_PORT : int = 8889 +MEDIAMTX_API_PORT : int = 9997 +MEDIAMTX_PATH : str = 'stream' +MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'mediamtx.yml') +MEDIAMTX_FALLBACK_BINARY : str = '/home/henry/local/bin/mediamtx' + + +def get_whip_url() -> str: + return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + MEDIAMTX_PATH + '/whip' + + +def get_api_url() -> str: + return 'http://localhost:' + str(MEDIAMTX_API_PORT) + + +def resolve_binary() -> str: + mediamtx_path = shutil.which('mediamtx') + + if mediamtx_path: + return mediamtx_path + return MEDIAMTX_FALLBACK_BINARY + + +def start() -> Optional[subprocess.Popen[bytes]]: + stop_stale() + mediamtx_binary = resolve_binary() + + return subprocess.Popen( + [ mediamtx_binary, MEDIAMTX_CONFIG ], + stdout = subprocess.DEVNULL, + stderr = subprocess.DEVNULL + ) + + +def stop(process : subprocess.Popen[bytes]) -> None: + process.terminate() + process.wait() + + +def stop_stale() -> None: + subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + subprocess.run([ 'fuser', '-k', str(MEDIAMTX_API_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) + time.sleep(1) + + +def wait_for_ready() -> bool: + api_url = get_api_url() + '/v3/paths/list' + + for _ in range(10): + try: + response = httpx.get(api_url, timeout = 1) + + if response.status_code == 200: + return True + except Exception: + pass + time.sleep(0.5) + return False