diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index e99fd624..e5f0aeff 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -14,7 +14,7 @@ from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state from facefusion import logger -from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_audio, websocket_stream_live, websocket_stream_mjpeg, websocket_stream_rtc, websocket_stream_whip, websocket_stream_whip_dc, websocket_stream_whip_py +from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_audio, websocket_stream_live, websocket_stream_mjpeg, websocket_stream_rtc, websocket_stream_rtc_relay, websocket_stream_whip, websocket_stream_whip_dc, websocket_stream_whip_py from facefusion.apis.middlewares.session import create_session_guard @@ -87,7 +87,8 @@ def create_api() -> Starlette: WebSocketRoute('/stream/whip-py', websocket_stream_whip_py, middleware = [ session_guard ]), WebSocketRoute('/stream/whip-dc', websocket_stream_whip_dc, middleware = [ session_guard ]), WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ]), -WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ]), + WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ]), + WebSocketRoute('/stream/rtc-relay', websocket_stream_rtc_relay, middleware = [ session_guard ]), WebSocketRoute('/stream/mjpeg', websocket_stream_mjpeg, middleware = [ session_guard ]), WebSocketRoute('/stream/audio', websocket_stream_audio, middleware = [ session_guard ]) ] diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 3fa3b01f..948a4df0 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -784,6 +784,80 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s encoder.wait(timeout = 5) +async def websocket_stream_rtc_relay(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + access_token = extract_access_token(websocket.scope) + session_id = session_manager.find_session_id(access_token) + + session_context.set_session_id(session_id) + source_paths = state_manager.get_item('source_paths') + + await websocket.accept(subprotocol = subprotocol) + + if source_paths: + from facefusion import rtc + import socket as sock + stream_path = 'stream/' + session_id + rtp_port = rtc.create_rtp_session(stream_path) + whep_url = 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep' + + audio_sock = sock.socket(sock.AF_INET, sock.SOCK_DGRAM) + relay_addr = ('127.0.0.1', rtp_port) + + latest_frame_holder : list = [None] + whep_sent = False + lock = threading.Lock() + stop_event = threading.Event() + ready_event = threading.Event() + worker = threading.Thread(target = run_h264_dc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, 'relay', stream_path, rtp_port), daemon = True) + worker.start() + + try: + while True: + message = await websocket.receive() + + if not whep_sent and ready_event.is_set(): + await websocket.send_text(whep_url) + whep_sent = True + + if message.get('bytes'): + data = message.get('bytes') + + if data[:2] == JPEG_MAGIC: + frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR) + + if numpy.any(frame): + with lock: + latest_frame_holder[0] = frame + + if data[:2] != JPEG_MAGIC: + rtc.init_opus_encoder() + + with rtc.audio_lock: + rtc.audio_buffer.extend(data) + needed = rtc.OPUS_FRAME_SAMPLES * 2 * 2 + + while len(rtc.audio_buffer) >= needed: + chunk = bytes(rtc.audio_buffer[:needed]) + del rtc.audio_buffer[:needed] + opus_pkt = rtc.encode_opus_frame(chunk) + + if opus_pkt: + audio_sock.sendto(b'\x02' + opus_pkt, relay_addr) + + except Exception as exception: + logger.error(str(exception), __name__) + + stop_event.set() + audio_sock.close() + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, worker.join, 10) + rtc.destroy_session(stream_path) + return + + await websocket.close() + + async def websocket_stream_rtc(websocket : WebSocket) -> None: subprotocol = get_sec_websocket_protocol(websocket.scope) access_token = extract_access_token(websocket.scope) diff --git a/facefusion/rtc.py b/facefusion/rtc.py index 2d888fd2..7c1f322d 100644 --- a/facefusion/rtc.py +++ b/facefusion/rtc.py @@ -255,14 +255,53 @@ def run_rtp_forwarder(stream_path : str) -> None: try: data, addr = rtp_fd.recvfrom(262144) - if not data: + if len(data) < 2: continue - send_to_viewers(stream_path, data) + tag = data[0] + payload = data[1:] + + if tag == 0x01: + send_to_viewers(stream_path, payload) + if tag == 0x02: + send_audio_to_viewers(stream_path, payload) except Exception: continue +def send_audio_to_viewers(stream_path : str, opus_data : bytes) -> None: + global audio_pts + + session = sessions.get(stream_path) + + if not session: + return + + viewers = session.get('viewers') + + if not viewers: + return + + buf = ctypes.create_string_buffer(opus_data) + + for viewer in viewers: + if not viewer.get('connected'): + continue + + audio_track_id = viewer.get('audio_track') + + if not audio_track_id: + continue + + if not lib.rtcIsOpen(audio_track_id): + continue + + lib.rtcSetTrackRtpTimestamp(audio_track_id, audio_pts & 0xFFFFFFFF) + lib.rtcSendMessage(audio_track_id, buf, len(opus_data)) + + audio_pts += OPUS_FRAME_SAMPLES + + send_start_time : float = 0 audio_pts : int = 0 opus_enc = None diff --git a/test_stream.html b/test_stream.html index 879ed8da..bc193d16 100644 --- a/test_stream.html +++ b/test_stream.html @@ -148,7 +148,8 @@ - + + @@ -323,7 +324,8 @@ var MODE_CONFIG = { 'whip-datachannel': { wsPath: '/stream/whip-dc', playback: 'whep' }, 'ws-fmp4': { wsPath: '/stream/live', playback: 'mse' }, 'datachannel-direct': { wsPath: '/stream/rtc', playback: 'whep' }, -'ws-mjpeg': { wsPath: '/stream/mjpeg', playback: 'mjpeg' } + 'datachannel-relay-py': { wsPath: '/stream/rtc-relay', playback: 'whep' }, + 'ws-mjpeg': { wsPath: '/stream/mjpeg', playback: 'mjpeg' } }; function getMode() {