diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 471694e9..0f8a25b8 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -1,8 +1,12 @@ +from contextlib import asynccontextmanager +from typing import AsyncGenerator + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute +from facefusion import mediamtx from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset from facefusion.apis.endpoints.capabilities import get_capabilities from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics @@ -13,6 +17,14 @@ from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_ from facefusion.apis.middlewares.session import create_session_guard +@asynccontextmanager +async def lifespan(app : Starlette) -> AsyncGenerator[None, None]: + mediamtx.start() + mediamtx.wait_for_ready() + yield + mediamtx.stop() + + def create_api() -> Starlette: session_guard = Middleware(create_session_guard) routes =\ @@ -35,7 +47,7 @@ def create_api() -> Starlette: WebSocketRoute('/stream/whip', websocket_stream_whip, middleware = [ session_guard ]) ] - api = Starlette(routes = routes) + api = Starlette(routes = routes, lifespan = lifespan) api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ]) return api diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 015966f9..441a89fb 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -49,7 +49,7 @@ async def websocket_stream(websocket : WebSocket) -> None: await websocket.close() -def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, audio_write_fd_holder : list) -> None: +def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, audio_write_fd_holder : list, stream_path : str) -> None: encoder = None audio_write_fd = -1 output_deque : Deque[VisionFrame] = deque() @@ -70,17 +70,25 @@ def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_ev output_deque.append(future_done.result()) futures.remove(future_done) + if encoder and encoder.poll() is not None: + stderr_output = encoder.stderr.read() if encoder.stderr else b'' + logger.error('encoder died with code ' + str(encoder.returncode) + ': ' + stderr_output.decode(), __name__) + break + while output_deque: temp_vision_frame = output_deque.popleft() if not encoder: height, width = temp_vision_frame.shape[:2] - encoder, audio_write_fd = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY) + encoder, audio_write_fd = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY, stream_path) audio_write_fd_holder[0] = audio_write_fd logger.info('whip encoder started ' + str(width) + 'x' + str(height), __name__) feed_whip_frame(encoder, temp_vision_frame) + if encoder and not ready_event.is_set() and mediamtx.is_path_ready(stream_path): + ready_event.set() + if capture_frame is None and not output_deque: time.sleep(0.005) @@ -104,28 +112,29 @@ async def websocket_stream_whip(websocket : WebSocket) -> None: await websocket.accept(subprotocol = subprotocol) if source_paths: - mediamtx_process = mediamtx.start() - is_ready = await asyncio.get_running_loop().run_in_executor(None, mediamtx.wait_for_ready) - - if not is_ready: - logger.error('mediamtx failed to start', __name__) - mediamtx.stop(mediamtx_process) - await websocket.close() - return - - logger.info('mediamtx ready', __name__) + stream_path = 'stream/' + session_id + mediamtx.remove_path(stream_path) + mediamtx.add_path(stream_path) + logger.info('mediamtx path added ' + stream_path, __name__) latest_frame_holder : list = [None] audio_write_fd_holder : list = [-1] + whep_sent = False lock = threading.Lock() stop_event = threading.Event() - worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event, audio_write_fd_holder), daemon = True) + ready_event = threading.Event() + worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, audio_write_fd_holder, stream_path), daemon = True) worker.start() try: while True: message = await websocket.receive() + if not whep_sent and ready_event.is_set(): + whep_url = mediamtx.get_whep_url(stream_path) + await websocket.send_text(whep_url) + whep_sent = True + if message.get('bytes'): data = message.get('bytes') @@ -143,10 +152,9 @@ async def websocket_stream_whip(websocket : WebSocket) -> None: logger.error(str(exception), __name__) stop_event.set() - worker.join(timeout = 10) - - if mediamtx_process: - mediamtx.stop(mediamtx_process) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, worker.join, 10) + mediamtx.remove_path(stream_path) return await websocket.close() diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 3fd69023..adf2d37b 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -27,10 +27,10 @@ def create_dtls_certificate() -> None: ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL) -def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> Tuple[subprocess.Popen[bytes], int]: +def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int, stream_path : str) -> Tuple[subprocess.Popen[bytes], int]: create_dtls_certificate() audio_read_fd, audio_write_fd = os.pipe() - whip_url = mediamtx.get_whip_url() + whip_url = mediamtx.get_whip_url(stream_path) commands = ffmpeg_builder.chain( [ '-use_wallclock_as_timestamps', '1' ], ffmpeg_builder.capture_video(), diff --git a/facefusion/mediamtx.py b/facefusion/mediamtx.py index 0389a9cc..b44bbc10 100644 --- a/facefusion/mediamtx.py +++ b/facefusion/mediamtx.py @@ -9,13 +9,17 @@ import httpx MEDIAMTX_WHIP_PORT : int = 8889 MEDIAMTX_API_PORT : int = 9997 -MEDIAMTX_PATH : str = 'stream' MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'mediamtx.yml') MEDIAMTX_FALLBACK_BINARY : str = '/home/henry/local/bin/mediamtx' +MEDIAMTX_PROCESS : Optional[subprocess.Popen[bytes]] = None -def get_whip_url() -> str: - return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + MEDIAMTX_PATH + '/whip' +def get_whip_url(stream_path : str) -> str: + return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + stream_path + '/whip' + + +def get_whep_url(stream_path : str) -> str: + return 'http://localhost:' + str(MEDIAMTX_WHIP_PORT) + '/' + stream_path + '/whep' def get_api_url() -> str: @@ -30,20 +34,25 @@ def resolve_binary() -> str: return MEDIAMTX_FALLBACK_BINARY -def start() -> Optional[subprocess.Popen[bytes]]: +def start() -> None: + global MEDIAMTX_PROCESS + stop_stale() mediamtx_binary = resolve_binary() - - return subprocess.Popen( + MEDIAMTX_PROCESS = subprocess.Popen( [ mediamtx_binary, MEDIAMTX_CONFIG ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL ) -def stop(process : subprocess.Popen[bytes]) -> None: - process.terminate() - process.wait() +def stop() -> None: + global MEDIAMTX_PROCESS + + if MEDIAMTX_PROCESS: + MEDIAMTX_PROCESS.terminate() + MEDIAMTX_PROCESS.wait() + MEDIAMTX_PROCESS = None def stop_stale() -> None: @@ -66,3 +75,30 @@ def wait_for_ready() -> bool: pass time.sleep(0.5) return False + + +def is_path_ready(stream_path : str) -> bool: + api_url = get_api_url() + '/v3/paths/get/' + stream_path + + try: + response = httpx.get(api_url, timeout = 1) + + if response.status_code == 200: + return response.json().get('ready', False) + except Exception: + pass + return False + + +def add_path(stream_path : str) -> bool: + api_url = get_api_url() + '/v3/config/paths/add/' + stream_path + response = httpx.post(api_url, json = {}, timeout = 5) + + return response.status_code == 200 + + +def remove_path(stream_path : str) -> bool: + api_url = get_api_url() + '/v3/config/paths/delete/' + stream_path + response = httpx.delete(api_url, timeout = 5) + + return response.status_code == 200 diff --git a/mediamtx.yml b/mediamtx.yml index d254215d..51ea060c 100644 --- a/mediamtx.yml +++ b/mediamtx.yml @@ -7,4 +7,3 @@ webrtcAddress: :8889 api: yes apiAddress: :9997 paths: - stream: diff --git a/test_whip_stream.html b/test_whip_stream.html new file mode 100644 index 00000000..d7e35d98 --- /dev/null +++ b/test_whip_stream.html @@ -0,0 +1,827 @@ + + + + +whip_stream test + + + + +

whip_stream — face swap via ffmpeg WHIP + mediamtx

+ +
+ +
+ 1 + session +
+ + + +
+
+ +
+ 2 + source face +
+ + source +
+
+ +
+ 3 + video source +
+ + + + +
+
+ +
+ 4 + options +
+ + + +
+
+ +
+ 5 + whip stream +
+ + +
+
+ +
+ +
+
+ processed output (webrtc via WHEP) + + +
+
+
ws
+
rtc
+
ice
+
codec
+
res
+
fps in
+
kbps in
+
frames
+
fps out
+
sent
+
up
+
+
+ +
+ 0:00 + + 0:00 +
+ +
+ + + +