shrink down to the release candidate

This commit is contained in:
henryruhs
2026-03-25 21:05:53 +01:00
parent 47b48e0de5
commit 28ded002fc
15 changed files with 124 additions and 3507 deletions
-110
View File
@@ -1,110 +0,0 @@
# Compiling libdatachannel
Prebuilt DLLs from OBS or pip lack VP8 support. We compile from source to get all codecs (H264, VP8, AV1, Opus).
## Source
```
git clone --depth 1 --recurse-submodules https://github.com/paullouisageneau/libdatachannel.git
cd libdatachannel
```
## Windows
Requirements: Visual Studio Build Tools 2019+ with C++ workload, cmake, ninja (available via conda).
```cmd
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF
cmake --build build --config Release
```
Output: `build/datachannel.dll`
Rename to: `windows-x64-openssl-h264-vp8-av1-opus-datachannel-<version>.dll`
Place in: `bin/`
## Linux
Requirements: gcc/g++, cmake, ninja-build, libssl-dev.
```bash
sudo apt install build-essential cmake ninja-build libssl-dev
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF
cmake --build build --config Release
```
Output: `build/libdatachannel.so`
Rename to: `linux-x64-openssl-h264-vp8-av1-opus-libdatachannel-<version>.so`
Install to: `/usr/local/lib/` or project `bin/`
If installed to a custom path, run `sudo ldconfig` or set `LD_LIBRARY_PATH`.
## macOS
Requirements: Xcode Command Line Tools, cmake, ninja.
```bash
xcode-select --install
brew install cmake ninja
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF
cmake --build build --config Release
```
For universal binary (arm64 + x86_64):
```bash
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DNO_WEBSOCKET=ON -DNO_EXAMPLES=ON -DNO_TESTS=ON -DUSE_NICE=OFF -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build build --config Release
```
Output: `build/libdatachannel.dylib`
Rename to: `macos-universal-openssl-h264-vp8-av1-opus-libdatachannel-<version>.dylib`
Install to: `/usr/local/lib/` or project `bin/`
## Naming convention
```
<os>-<arch>-<tls>-<codecs>-datachannel-<version>.<ext>
```
- os: `windows`, `linux`, `macos`
- arch: `x64`, `arm64`, `universal`
- tls: `openssl` (default), `gnutls`, `mbedtls`
- codecs: supported packetizers, e.g. `h264-vp8-av1-opus`
- version: libdatachannel version, e.g. `0.24.1`
## Verifying the build
```python
import ctypes
lib = ctypes.CDLL('path/to/datachannel.dll')
for fn in ['rtcSetH264Packetizer', 'rtcSetVP8Packetizer', 'rtcSetAV1Packetizer', 'rtcSetOpusPacketizer']:
try:
getattr(lib, fn)
print(f'{fn}: OK')
except AttributeError:
print(f'{fn}: MISSING')
```
## CMake flags reference
| Flag | Default | Purpose |
|---|---|---|
| `NO_WEBSOCKET` | OFF | Disable WebSocket support (not needed) |
| `NO_MEDIA` | OFF | Disable media transport (must be OFF for codecs) |
| `NO_EXAMPLES` | OFF | Skip building examples |
| `NO_TESTS` | OFF | Skip building tests |
| `USE_NICE` | OFF | Use libnice instead of libjuice for ICE |
| `USE_GNUTLS` | OFF | Use GnuTLS instead of OpenSSL |
| `USE_MBEDTLS` | OFF | Use Mbed TLS instead of OpenSSL |
## Runtime dependencies
- **libopus**: Required for audio encoding. Install via `conda install -c conda-forge libopus` (Windows) or `apt install libopus-dev` (Linux) or `brew install opus` (macOS).
- **OpenSSL**: Usually bundled or system-provided. On Windows, conda provides it.
+45 -164
View File
@@ -2,7 +2,6 @@ import os
import platform
import signal
import subprocess
import sys
import time
import httpx
@@ -12,16 +11,8 @@ API_PORT : int = 8400
HTML_FILE : str = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_stream.html')
SOURCE_FILE : str = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.assets', 'examples', 'source.jpg')
def is_windows() -> bool:
return platform.system().lower() == 'windows'
def is_macos() -> bool:
return platform.system().lower() == 'darwin'
if is_windows():
if platform.system().lower() == 'windows':
VIDEO_FILE : str = 'C:\\Users\\info\\Downloads\\face8k.mp4'
elif is_macos():
VIDEO_FILE : str = '/Users/henry/Downloads/copy_face_instant.mp4'
else:
VIDEO_FILE : str = '/home/henry/Documents/examples/download.mp4'
@@ -32,25 +23,12 @@ def safe_print(text : str) -> None:
except UnicodeEncodeError:
print(text.encode('ascii', errors='replace').decode('ascii'))
_ALL_MODES =\
[
'whip-mediamtx',
'whip-python',
'whip-datachannel',
'ws-fmp4',
'datachannel-direct',
'datachannel-relay-py',
'ws-mjpeg'
]
MODES = [ m for m in _ALL_MODES if not (is_macos() and m == 'whip-mediamtx') ]
def start_api() -> subprocess.Popen:
env = os.environ.copy()
python_cmd = 'python' if is_windows() else 'python3'
python_cmd = 'python' if platform.system().lower() == 'windows' else 'python3'
if not is_windows() and not is_macos():
if platform.system().lower() != 'windows':
env['LD_LIBRARY_PATH'] = '/home/henry/local/lib:' + env.get('LD_LIBRARY_PATH', '')
proc = subprocess.Popen(
@@ -81,7 +59,7 @@ def wait_for_api(timeout : int = 60) -> bool:
def stop_api(proc : subprocess.Popen) -> None:
if is_windows():
if platform.system().lower() == 'windows':
proc.terminate()
else:
proc.send_signal(signal.SIGTERM)
@@ -95,64 +73,32 @@ def stop_api(proc : subprocess.Popen) -> None:
time.sleep(1)
def kill_port_windows(port : int) -> None:
result = subprocess.run(
[ 'netstat', '-ano' ],
capture_output = True, text = True
)
for line in result.stdout.splitlines():
if ':' + str(port) + ' ' in line and ('LISTENING' in line or 'ESTABLISHED' in line):
parts = line.split()
pid = parts[-1]
if pid.isdigit() and int(pid) > 0:
subprocess.run([ 'taskkill', '/F', '/PID', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
def kill_port_macos(port : int) -> None:
pids = set()
for proto in [ 'tcp', 'udp' ]:
result = subprocess.run(
[ 'lsof', '-ti', proto + ':' + str(port) ],
capture_output = True, text = True
)
for pid in result.stdout.split():
if pid.isdigit():
pids.add(pid)
for pid in pids:
subprocess.run([ 'kill', '-9', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
def kill_stale() -> None:
ports = [ API_PORT, 8889, 8189, 9997, 8890, 8891, 8892 ]
ports = [ API_PORT ]
if is_windows():
if platform.system().lower() == 'windows':
for port in ports:
kill_port_windows(port)
elif is_macos():
for port in ports:
kill_port_macos(port)
result = subprocess.run([ 'netstat', '-ano' ], capture_output = True, text = True)
for line in result.stdout.splitlines():
if ':' + str(port) + ' ' in line and ('LISTENING' in line or 'ESTABLISHED' in line):
parts = line.split()
pid = parts[-1]
if pid.isdigit() and int(pid) > 0:
subprocess.run([ 'taskkill', '/F', '/PID', pid ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
else:
subprocess.run([ 'fuser', '-k', str(API_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '8889/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '9997/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '8890/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '8891/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
subprocess.run([ 'fuser', '-k', '8892/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
for port in ports:
subprocess.run([ 'fuser', '-k', str(port) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
time.sleep(2)
def test_mode(mode : str) -> dict:
result = {'mode': mode, 'session': False, 'source': False, 'video': False, 'ws_open': False, 'stream_ready': False, 'playback': False, 'error': None}
def test_rtc() -> dict:
result = {'session': False, 'source': False, 'video': False, 'ws_open': False, 'stream_ready': False, 'playback': False, 'error': None}
print('\n' + '=' * 60)
print('TESTING: ' + mode)
print('TESTING: libdatachannel direct (RTC)')
print('=' * 60)
kill_stale()
@@ -236,11 +182,7 @@ def test_mode(mode : str) -> dict:
stop_api(api_proc)
return result
print(' video OK, selecting mode: ' + mode)
page.select_option('#streamMode', mode)
time.sleep(0.5)
print(' starting stream...')
print(' video OK, starting stream...')
for _ in range(10):
time.sleep(1)
@@ -283,68 +225,17 @@ def test_mode(mode : str) -> dict:
if 'stream ready' in log_text or 'WHEP' in log_text:
result['stream_ready'] = True
if mode == 'ws-mjpeg':
result['stream_ready'] = True
try:
has_img = page.evaluate('!!document.getElementById("outputVideo")._mjpegImg && !!document.getElementById("outputVideo")._mjpegImg.src')
if has_img:
result['playback'] = True
print(' [' + str(i) + 's] MJPEG receiving frames')
break
except Exception:
pass
if mode == 'ws-fmp4':
if 'MSE source buffer ready' in log_text:
result['stream_ready'] = True
try:
mse_info = page.evaluate('''() => {
var v = document.getElementById("outputVideo");
var ms = v._mediaSource || window.mediaSource;
var buf = (v.buffered && v.buffered.length > 0) ? v.buffered.end(0) : 0;
return { time: v.currentTime, buffered: buf, readyState: v.readyState, networkState: v.networkState };
}''')
buffered = mse_info.get('buffered', 0)
if buffered > 0 or mse_info.get('time', 0) > 0:
result['playback'] = True
print(' [' + str(i) + 's] MSE buffered=' + str(round(buffered, 2)) + ' time=' + str(round(mse_info.get('time', 0), 2)))
break
if i % 5 == 0:
print(' [' + str(i) + 's] MSE: ' + str(mse_info))
except Exception:
pass
else:
try:
frames_val = int(frames_stat) if frames_stat and frames_stat != '--' else 0
except ValueError:
frames_val = 0
if frames_val > 0:
result['playback'] = True
print(' [' + str(i) + 's] frames=' + str(frames_val) + ' fps=' + fps_stat + ' rtc=' + rtc_stat)
break
try:
rtc_stats = page.evaluate('''() => {
if (!window.pc) return '';
return pc.getStats().then(stats => {
var r = '';
stats.forEach(report => {
if (report.type === 'inbound-rtp' && report.kind === 'video') {
r = 'pkts=' + (report.packetsReceived||0) + ' bytes=' + (report.bytesReceived||0) + ' lost=' + (report.packetsLost||0) + ' dropped=' + (report.framesDropped||0) + ' dec=' + (report.decoderImplementation||'?') + ' kf=' + (report.keyFramesDecoded||0) + ' nacks=' + (report.nackCount||0) + ' plis=' + (report.pliCount||0);
}
});
return r;
});
}''')
except Exception:
rtc_stats = ''
print(' [' + str(i) + 's] ws=' + ws_stat + ' rtc=' + rtc_stat + ' frames=' + frames_stat + ' ' + str(rtc_stats))
frames_val = int(frames_stat) if frames_stat and frames_stat != '--' else 0
except ValueError:
frames_val = 0
if frames_val > 0:
result['playback'] = True
print(' [' + str(i) + 's] frames=' + str(frames_val) + ' fps=' + fps_stat + ' rtc=' + rtc_stat)
break
print(' [' + str(i) + 's] ws=' + ws_stat + ' rtc=' + rtc_stat + ' frames=' + frames_stat)
if not result.get('playback'):
log_text = page.locator('#log').text_content()
@@ -376,36 +267,26 @@ def test_mode(mode : str) -> dict:
def main() -> None:
modes_to_test = MODES
if len(sys.argv) > 1:
modes_to_test = sys.argv[1:]
results = []
for mode in modes_to_test:
result = test_mode(mode)
results.append(result)
result = test_rtc()
print('\n\n' + '=' * 60)
print('SUMMARY')
print('RESULT')
print('=' * 60)
for r in results:
status = 'PASS' if r.get('playback') else 'FAIL'
error = ' (' + r.get('error', '') + ')' if r.get('error') else ''
flags = []
status = 'PASS' if result.get('playback') else 'FAIL'
error = ' (' + result.get('error', '') + ')' if result.get('error') else ''
flags = []
if r.get('session'):
flags.append('session')
if r.get('ws_open'):
flags.append('ws')
if r.get('stream_ready'):
flags.append('ready')
if r.get('playback'):
flags.append('playback')
if result.get('session'):
flags.append('session')
if result.get('ws_open'):
flags.append('ws')
if result.get('stream_ready'):
flags.append('ready')
if result.get('playback'):
flags.append('playback')
print(' ' + status + ' ' + r.get('mode') + ' [' + ','.join(flags) + ']' + error)
print(' ' + status + ' datachannel-direct [' + ','.join(flags) + ']' + error)
if __name__ == '__main__':
-336
View File
@@ -1,336 +0,0 @@
import asyncio
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional
import numpy
from aiortc import RTCPeerConnection, RTCSessionDescription, AudioStreamTrack, VideoStreamTrack
from av import AudioFrame, VideoFrame
from facefusion import logger
from facefusion.types import VisionFrame
BRIDGE_PORT_START : int = 8893
AUDIO_SAMPLE_RATE : int = 48000
class FramePushTrack(VideoStreamTrack):
kind = 'video'
def __init__(self) -> None:
super().__init__()
self._frame : Optional[VisionFrame] = None
self._lock = threading.Lock()
self._started = False
def push(self, vision_frame : VisionFrame) -> None:
with self._lock:
self._frame = vision_frame
async def recv(self) -> VideoFrame:
pts, time_base = await self.next_timestamp()
with self._lock:
frame_data = self._frame
if frame_data is None:
frame_data = numpy.zeros((240, 320, 3), dtype = numpy.uint8)
if not self._started:
self._started = True
logger.info('aiortc track sending first frame', __name__)
video_frame = VideoFrame.from_ndarray(frame_data, format = 'bgr24')
video_frame.pts = pts
video_frame.time_base = time_base
return video_frame
class AudioPushTrack(AudioStreamTrack):
kind = 'audio'
def __init__(self) -> None:
super().__init__()
self._buffer = bytearray()
self._lock = threading.Lock()
self._pts = 0
self._frame_samples = 960
def push(self, pcm_data : bytes) -> None:
with self._lock:
self._buffer.extend(pcm_data)
if len(self._buffer) > AUDIO_SAMPLE_RATE * 4:
self._buffer = self._buffer[-AUDIO_SAMPLE_RATE * 4:]
async def recv(self) -> AudioFrame:
await asyncio.sleep(self._frame_samples / AUDIO_SAMPLE_RATE)
needed = self._frame_samples * 2 * 2
with self._lock:
if len(self._buffer) >= needed:
chunk = bytes(self._buffer[:needed])
del self._buffer[:needed]
else:
chunk = None
if chunk:
pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1)
else:
pcm = numpy.zeros((1, self._frame_samples * 2), dtype = numpy.int16)
audio_frame = AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo')
audio_frame.sample_rate = AUDIO_SAMPLE_RATE
audio_frame.pts = self._pts
self._pts += self._frame_samples
return audio_frame
class AiortcBridge:
def __init__(self) -> None:
global BRIDGE_PORT_START
self.port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self.video_track = FramePushTrack()
self.audio_track = AudioPushTrack()
self.pcs : list = []
self._http_thread : Optional[threading.Thread] = None
self._running = False
self._has_viewer = False
self._loop = None
async def start(self) -> None:
self._running = True
self._loop = asyncio.get_event_loop()
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
self._http_thread.start()
logger.info('aiortc bridge started on port ' + str(self.port), __name__)
async def stop(self) -> None:
self._running = False
for pc in self.pcs:
try:
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(pc.close(), loop)
except Exception:
pass
def push_frame(self, vision_frame : VisionFrame) -> None:
self.video_track.push(vision_frame)
def push_audio(self, audio_data : bytes) -> None:
self.audio_track.push(audio_data)
def has_viewer(self) -> bool:
return self._has_viewer
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_pc(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whep error: ' + str(exception), __name__)
return None
async def _create_pc(self, sdp_offer : str) -> Optional[str]:
pc = RTCPeerConnection()
self.pcs.append(pc)
pc.addTrack(self.video_track)
pc.addTrack(self.audio_track)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
self._has_viewer = True
return pc.localDescription.sdp
def _run_http(self) -> None:
bridge = self
class WhepHandler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
self.end_headers()
def do_POST(self) -> None:
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
answer = bridge._handle_whep(body)
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
else:
self.send_response(500)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
server = HTTPServer(('0.0.0.0', self.port), WhepHandler)
server.timeout = 1
while self._running:
server.handle_request()
class WhipAiortcBridge:
def __init__(self) -> None:
global BRIDGE_PORT_START
self.port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self.whip_port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self._ingest_pc = None
self._relay_track = None
self._viewer_pcs : list = []
self._http_thread : Optional[threading.Thread] = None
self._running = False
self._loop = None
self._ingest_ready = False
async def start(self) -> None:
self._running = True
self._loop = asyncio.get_event_loop()
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
self._http_thread.start()
logger.info('whip-aiortc bridge whip=' + str(self.whip_port) + ' whep=' + str(self.port), __name__)
async def stop(self) -> None:
self._running = False
if self._ingest_pc:
await self._ingest_pc.close()
for pc in self._viewer_pcs:
await pc.close()
def get_whip_url(self) -> str:
return 'http://localhost:' + str(self.whip_port) + '/whip'
def get_whep_url(self) -> str:
return 'http://localhost:' + str(self.port) + '/whep'
def is_ready(self) -> bool:
return self._ingest_ready
def _handle_whip(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_ingest(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whip ingest error: ' + str(exception), __name__)
return None
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_viewer(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whep error: ' + str(exception), __name__)
return None
async def _create_ingest(self, sdp_offer : str) -> Optional[str]:
from aiortc import MediaStreamTrack
from aiortc.contrib.media import MediaRelay
pc = RTCPeerConnection()
self._ingest_pc = pc
self._relay = MediaRelay()
@pc.on('track')
def on_track(track : MediaStreamTrack) -> None:
if track.kind == 'video':
self._relay_track = self._relay.subscribe(track)
self._ingest_ready = True
logger.info('whip ingest video track received', __name__)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return pc.localDescription.sdp
async def _create_viewer(self, sdp_offer : str) -> Optional[str]:
pc = RTCPeerConnection()
self._viewer_pcs.append(pc)
if self._relay_track:
pc.addTrack(self._relay_track)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return pc.localDescription.sdp
def _run_http(self) -> None:
bridge = self
class Handler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, DELETE')
self.send_header('Access-Control-Allow-Headers', 'Content-Type, Authorization')
self.end_headers()
def do_POST(self) -> None:
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
path = self.path
if '/whip' in path:
answer = bridge._handle_whip(body)
elif '/whep' in path:
answer = bridge._handle_whep(body)
else:
self.send_response(404)
self.end_headers()
return
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Location', path)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Expose-Headers', 'Location')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
else:
self.send_response(500)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
whip_server = HTTPServer(('0.0.0.0', self.whip_port), Handler)
whip_server.timeout = 0.5
whep_server = HTTPServer(('0.0.0.0', self.port), Handler)
whep_server.timeout = 0.5
while self._running:
whip_server.handle_request()
whep_server.handle_request()
+4 -43
View File
@@ -6,37 +6,19 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route, WebSocketRoute
from facefusion import logger, mediamtx
from facefusion import logger
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
from facefusion.apis.endpoints.capabilities import get_capabilities
from facefusion.apis.endpoints.metrics import get_metrics, websocket_metrics
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
from facefusion.common_helper import is_linux, is_windows
from facefusion.apis.endpoints.stream import websocket_stream, websocket_stream_audio, websocket_stream_live, websocket_stream_mjpeg, websocket_stream_rtc, websocket_stream_rtc_relay, websocket_stream_whip, websocket_stream_whip_dc, websocket_stream_whip_py
from facefusion.apis.endpoints.stream import post_whep, websocket_stream, websocket_stream_rtc
from facefusion.apis.middlewares.session import create_session_guard
@asynccontextmanager
async def lifespan(app : Starlette) -> AsyncGenerator[None, None]:
if is_linux():
mediamtx.start()
mediamtx.wait_for_ready()
try:
from facefusion import webrtc_sfu
webrtc_sfu.start()
except Exception as exception:
logger.warn('webrtc sfu: ' + str(exception), __name__)
try:
from facefusion import whip_relay
whip_relay.start()
whip_relay.wait_for_ready()
except Exception as exception:
logger.warn('whip relay: ' + str(exception), __name__)
try:
from facefusion import rtc
rtc.start()
@@ -45,21 +27,6 @@ async def lifespan(app : Starlette) -> AsyncGenerator[None, None]:
yield
if is_linux():
mediamtx.stop()
try:
from facefusion import webrtc_sfu
webrtc_sfu.stop()
except Exception:
pass
try:
from facefusion import whip_relay
whip_relay.stop()
except Exception:
pass
try:
from facefusion import rtc
rtc.stop()
@@ -85,15 +52,9 @@ def create_api() -> Starlette:
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]),
Route('/stream/{session_id}/whep', post_whep, methods = [ 'POST' ]),
WebSocketRoute('/stream', websocket_stream, middleware = [ session_guard ]),
WebSocketRoute('/stream/whip', websocket_stream_whip, middleware = [ session_guard ]),
WebSocketRoute('/stream/whip-py', websocket_stream_whip_py, middleware = [ session_guard ]),
WebSocketRoute('/stream/whip-dc', websocket_stream_whip_dc, middleware = [ session_guard ]),
WebSocketRoute('/stream/live', websocket_stream_live, middleware = [ session_guard ]),
WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ]),
WebSocketRoute('/stream/rtc-relay', websocket_stream_rtc_relay, middleware = [ session_guard ]),
WebSocketRoute('/stream/mjpeg', websocket_stream_mjpeg, middleware = [ session_guard ]),
WebSocketRoute('/stream/audio', websocket_stream_audio, middleware = [ session_guard ])
WebSocketRoute('/stream/rtc', websocket_stream_rtc, middleware = [ session_guard ])
]
api = Starlette(routes = routes, lifespan = lifespan)
+32 -710
View File
@@ -1,25 +1,25 @@
import asyncio
import os as _os
import os
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Deque, List
from typing import List
import cv2
import numpy
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from facefusion import logger, session_context, session_manager, state_manager
from facefusion.common_helper import is_windows
from facefusion.apis.stream_helper import STREAM_AUDIO_RATE
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.apis.session_helper import extract_access_token
from facefusion import mediamtx
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, close_fmp4_encoder, close_whip_encoder, collect_fmp4_chunks, create_fmp4_encoder, create_vp8_pipe_encoder, create_whip_encoder, feed_whip_audio, feed_whip_frame, process_stream_frame, read_fmp4_output
from facefusion.apis.stream_helper import STREAM_FPS, STREAM_QUALITY, create_vp8_pipe_encoder, feed_whip_frame, process_stream_frame
from facefusion.streamer import process_vision_frame
from facefusion.types import VisionFrame
JPEG_MAGIC : bytes = b'\xff\xd8'
@@ -52,454 +52,16 @@ async def websocket_stream(websocket : WebSocket) -> None:
await websocket.close()
def run_whip_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, audio_write_fd_holder : list, stream_path : str) -> None:
encoder = None
audio_write_fd = -1
output_deque : Deque[VisionFrame] = deque()
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
futures = []
while not stop_event.is_set():
with lock:
capture_frame = latest_frame_holder[0]
latest_frame_holder[0] = None
if capture_frame is not None:
future = executor.submit(process_stream_frame, capture_frame)
futures.append(future)
for future_done in [ future for future in futures if future.done() ]:
output_deque.append(future_done.result())
futures.remove(future_done)
if encoder and encoder.poll() is not None:
stderr_output = encoder.stderr.read() if encoder.stderr else b''
logger.error('encoder died with code ' + str(encoder.returncode) + ': ' + stderr_output.decode(), __name__)
break
while output_deque:
temp_vision_frame = output_deque.popleft()
if not encoder:
height, width = temp_vision_frame.shape[:2]
whip_url = mediamtx.get_whip_url(stream_path)
encoder, audio_write_fd = create_whip_encoder(width, height, STREAM_FPS, STREAM_QUALITY, whip_url)
audio_write_fd_holder[0] = audio_write_fd
logger.info('whip encoder started ' + str(width) + 'x' + str(height), __name__)
feed_whip_frame(encoder, temp_vision_frame)
if encoder and not ready_event.is_set() and mediamtx.is_path_ready(stream_path):
ready_event.set()
if capture_frame is None and not output_deque:
time.sleep(0.005)
if encoder:
close_whip_encoder(encoder, audio_write_fd)
async def websocket_stream_whip(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
session_id = session_manager.find_session_id(access_token)
session_context.set_session_id(session_id)
source_paths = state_manager.get_item('source_paths')
await websocket.accept(subprotocol = subprotocol)
if source_paths:
stream_path = 'stream/' + session_id
mediamtx.remove_path(stream_path)
mediamtx.add_path(stream_path)
logger.info('mediamtx path added ' + stream_path, __name__)
latest_frame_holder : list = [None]
audio_write_fd_holder : list = [-1]
whep_sent = False
lock = threading.Lock()
stop_event = threading.Event()
ready_event = threading.Event()
worker = threading.Thread(target = run_whip_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, audio_write_fd_holder, stream_path), daemon = True)
worker.start()
try:
while True:
message = await websocket.receive()
if not whep_sent and ready_event.is_set():
whep_url = mediamtx.get_whep_url(stream_path)
await websocket.send_text(whep_url)
whep_sent = True
if message.get('bytes'):
data = message.get('bytes')
if data[:2] == JPEG_MAGIC:
frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
if numpy.any(frame):
with lock:
latest_frame_holder[0] = frame
if data[:2] != JPEG_MAGIC and audio_write_fd_holder[0] > 0:
feed_whip_audio(audio_write_fd_holder[0], data)
except Exception as exception:
logger.error(str(exception), __name__)
stop_event.set()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, worker.join, 10)
mediamtx.remove_path(stream_path)
return
await websocket.close()
def run_audio_silence_feeder(audio_write_fd_holder : list, stop_event : threading.Event, audio_active_event : threading.Event) -> None:
frame_bytes = STREAM_AUDIO_RATE // 50 * 2 * 2
silence = b'\x00' * frame_bytes
while not stop_event.is_set():
if not audio_active_event.is_set():
fd = audio_write_fd_holder[0]
if fd > 0:
try:
_os.write(fd, silence)
except OSError:
break
time.sleep(0.02)
def run_fmp4_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, output_chunks : List[bytes], output_lock : threading.Lock, audio_write_fd_holder : list, audio_active_event : threading.Event) -> None:
encoder = None
audio_write_fd = -1
reader_thread = None
output_deque : Deque[VisionFrame] = deque()
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
futures = []
while not stop_event.is_set():
with lock:
capture_frame = latest_frame_holder[0]
latest_frame_holder[0] = None
if capture_frame is not None:
future = executor.submit(process_stream_frame, capture_frame)
futures.append(future)
for future_done in [ future for future in futures if future.done() ]:
output_deque.append(future_done.result())
futures.remove(future_done)
if encoder and encoder.poll() is not None:
stderr_output = encoder.stderr.read() if encoder.stderr else b''
logger.error('fmp4 encoder died with code ' + str(encoder.returncode) + ': ' + stderr_output.decode(), __name__)
break
while output_deque:
temp_vision_frame = output_deque.popleft()
if not encoder:
height, width = temp_vision_frame.shape[:2]
encoder, audio_write_fd = create_fmp4_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
audio_write_fd_holder[0] = audio_write_fd
reader_thread = threading.Thread(target = read_fmp4_output, args = (encoder, output_chunks, output_lock), daemon = True)
reader_thread.start()
silence_thread = threading.Thread(target = run_audio_silence_feeder, args = (audio_write_fd_holder, stop_event, audio_active_event), daemon = True)
silence_thread.start()
logger.info('fmp4 encoder started ' + str(width) + 'x' + str(height), __name__)
feed_whip_frame(encoder, temp_vision_frame)
if capture_frame is None and not output_deque:
time.sleep(0.005)
if encoder:
close_fmp4_encoder(encoder, audio_write_fd)
async def websocket_stream_live(websocket : WebSocket) -> None:
	# Live streaming endpoint: receives JPEG frames and raw audio over the
	# websocket, runs them through the fmp4 pipeline and sends encoded
	# fragmented-MP4 chunks back on the same socket.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		latest_frame_holder : list = [None]
		audio_write_fd_holder : list = [-1]
		output_chunks : List[bytes] = []
		lock = threading.Lock()
		output_lock = threading.Lock()
		stop_event = threading.Event()
		audio_active_event = threading.Event()
		worker = threading.Thread(target = run_fmp4_pipeline, args = (latest_frame_holder, lock, stop_event, output_chunks, output_lock, audio_write_fd_holder, audio_active_event), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# flush any encoded output produced since the last receive
				chunks = collect_fmp4_chunks(output_chunks, output_lock)
				if chunks:
					await websocket.send_bytes(chunks)
				if message.get('bytes'):
					data = message.get('bytes')
					# the JPEG magic prefix distinguishes video frames from audio payloads
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
					if data[:2] != JPEG_MAGIC and audio_write_fd_holder[0] > 0:
						# real audio arrived: disable the silence feeder and forward it
						audio_active_event.set()
						feed_whip_audio(audio_write_fd_holder[0], data)
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		# join on an executor thread so the event loop is not blocked during shutdown
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		return
	await websocket.close()
def run_mjpeg_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, output_holder : list, output_lock : threading.Lock) -> None:
	# Background worker: process the newest frame and publish it as a JPEG
	# into output_holder (a one-slot mailbox, newest frame wins).
	output_deque : Deque[VisionFrame] = deque()
	with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
		futures = []
		while not stop_event.is_set():
			# take ownership of the latest frame and clear the slot
			with lock:
				capture_frame = latest_frame_holder[0]
				latest_frame_holder[0] = None
			if capture_frame is not None:
				future = executor.submit(process_stream_frame, capture_frame)
				futures.append(future)
			# harvest finished frames in completion order
			for future_done in [ future for future in futures if future.done() ]:
				output_deque.append(future_done.result())
				futures.remove(future_done)
			while output_deque:
				temp_vision_frame = output_deque.popleft()
				is_success, encoded = cv2.imencode('.jpg', temp_vision_frame, [cv2.IMWRITE_JPEG_QUALITY, 92])
				if is_success:
					with output_lock:
						# later frames overwrite earlier ones: only the newest JPEG is kept
						output_holder[0] = encoded.tobytes()
			if capture_frame is None and not output_deque:
				# idle: avoid busy-spinning while waiting for new frames
				time.sleep(0.005)
async def websocket_stream_mjpeg(websocket : WebSocket) -> None:
	# MJPEG streaming endpoint: receives JPEG frames over the websocket and
	# replies with the latest processed frame, one JPEG per message.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		latest_frame_holder : list = [None]
		output_holder : list = [None]
		lock = threading.Lock()
		output_lock = threading.Lock()
		stop_event = threading.Event()
		worker = threading.Thread(target = run_mjpeg_pipeline, args = (latest_frame_holder, lock, stop_event, output_holder, output_lock), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# pop the newest processed JPEG, if any was produced since last receive
				with output_lock:
					jpeg_data = output_holder[0]
					output_holder[0] = None
				if jpeg_data:
					await websocket.send_bytes(jpeg_data)
				if message.get('bytes'):
					data = message.get('bytes')
					# only JPEG payloads are frames; anything else is ignored here
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		return
	await websocket.close()
async def websocket_stream_audio(websocket : WebSocket) -> None:
	# Audio loopback endpoint: echoes every binary payload back to the client
	# unchanged (used to monitor the audio path end to end).
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	await websocket.accept(subprotocol = subprotocol)
	try:
		while True:
			message = await websocket.receive()
			if message.get('bytes'):
				await websocket.send_bytes(message.get('bytes'))
	except Exception as exception:
		# receive() raises on client disconnect; log it instead of silently
		# swallowing, consistent with the other stream handlers in this module
		logger.error(str(exception), __name__)
async def websocket_stream_whip_py(websocket : WebSocket) -> None:
	# WHIP endpoint (aiortc backend): frames and audio come in over the websocket,
	# processed media goes out through an aiortc bridge; the client receives the
	# local WHEP playback URL as a text message once the bridge is ready.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		# imported lazily so aiortc is only required when this endpoint is used
		from facefusion.aiortc_bridge import AiortcBridge
		bridge = AiortcBridge()
		await bridge.start()
		whep_url = 'http://localhost:' + str(bridge.port) + '/whep'
		latest_frame_holder : list = [None]
		whep_sent = False
		lock = threading.Lock()
		stop_event = threading.Event()
		ready_event = threading.Event()
		worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# announce the playback URL exactly once, after the pipeline is ready
				if not whep_sent and ready_event.is_set():
					await websocket.send_text(whep_url)
					whep_sent = True
				if message.get('bytes'):
					data = message.get('bytes')
					# the JPEG magic prefix distinguishes video frames from audio payloads
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
					if data[:2] != JPEG_MAGIC:
						bridge.push_audio(data)
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		await bridge.stop()
		return
	await websocket.close()
def run_aiortc_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, bridge : object) -> None:
	# Background worker: process the newest frame and push results into the
	# aiortc bridge; signals ready_event once frames are flowing.
	output_deque : Deque[VisionFrame] = deque()
	with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
		futures = []
		while not stop_event.is_set():
			# the holder acts as a one-slot mailbox: take the frame and clear the slot
			with lock:
				capture_frame = latest_frame_holder[0]
				latest_frame_holder[0] = None
			if capture_frame is not None:
				future = executor.submit(process_stream_frame, capture_frame)
				futures.append(future)
			# harvest finished frames in completion order
			for future_done in [ future for future in futures if future.done() ]:
				output_deque.append(future_done.result())
				futures.remove(future_done)
			while output_deque:
				temp_vision_frame = output_deque.popleft()
				bridge.push_frame(temp_vision_frame)
				if not ready_event.is_set():
					# NOTE(review): fixed 2s delay before signalling readiness —
					# presumably gives the bridge time to negotiate; confirm necessity
					time.sleep(2)
					ready_event.set()
			if capture_frame is None and not output_deque:
				# idle: avoid busy-spinning while waiting for new frames
				time.sleep(0.005)
def read_h264_output(process, h264_chunks : List[bytes], h264_lock : threading.Lock) -> None:
	# Blocking reader thread: drains the raw H264 byte stream from the encoder's
	# stdout into the shared chunk list until EOF.
	fd = process.stdout.fileno()
	if not is_windows():
		import fcntl
		# clear O_NONBLOCK so os.read waits for data instead of raising EAGAIN
		flags = fcntl.fcntl(fd, fcntl.F_GETFL)
		fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~_os.O_NONBLOCK)
	while True:
		chunk = _os.read(fd, 4096)
		if not chunk:
			# EOF: the encoder closed its stdout
			break
		with h264_lock:
			h264_chunks.append(chunk)
def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lock) -> None:
fd = process.stdout.fileno()
pipe_handle = process.stdout.fileno()
if not is_windows():
import fcntl
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~_os.O_NONBLOCK)
if os.name != 'nt':
os.set_blocking(pipe_handle, True)
header = b''
while len(header) < 32:
chunk = _os.read(fd, 32 - len(header))
chunk = os.read(pipe_handle, 32 - len(header))
if not chunk:
return
@@ -510,7 +72,7 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo
frame_header = b''
while len(frame_header) < 12:
chunk = _os.read(fd, 12 - len(frame_header))
chunk = os.read(pipe_handle, 12 - len(frame_header))
if not chunk:
return
@@ -521,7 +83,7 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo
frame_data = b''
while len(frame_data) < frame_size:
chunk = _os.read(fd, frame_size - len(frame_data))
chunk = os.read(pipe_handle, frame_size - len(frame_data))
if not chunk:
return
@@ -532,203 +94,13 @@ def read_ivf_frames(process, frame_list : List[bytes], frame_lock : threading.Lo
frame_list.append(frame_data)
def run_h264_dc_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, backend : str, stream_path : str, rtp_port : int) -> None:
	# Background worker: encode processed frames to VP8 (via a pipe encoder) and
	# deliver the frames either over a local UDP relay ('relay' backend) or
	# directly through the rtc module ('rtc' backend).
	encoder = None
	reader_thread = None
	vp8_frames : List[bytes] = []
	vp8_lock = threading.Lock()
	output_deque : Deque[VisionFrame] = deque()
	udp_sock = None
	if backend == 'relay':
		import socket as sock
		udp_sock = sock.socket(sock.AF_INET, sock.SOCK_DGRAM)
	with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
		futures = []
		while not stop_event.is_set():
			# the holder acts as a one-slot mailbox: take the frame and clear the slot
			with lock:
				capture_frame = latest_frame_holder[0]
				latest_frame_holder[0] = None
			if capture_frame is not None:
				future = executor.submit(process_stream_frame, capture_frame)
				futures.append(future)
			# harvest finished frames in completion order
			for future_done in [ future for future in futures if future.done() ]:
				output_deque.append(future_done.result())
				futures.remove(future_done)
			# stop the pipeline if ffmpeg exited; surface its stderr for diagnosis
			if encoder and encoder.poll() is not None:
				stderr_output = encoder.stderr.read() if encoder.stderr else b''
				logger.error('vp8 encoder died: ' + stderr_output.decode(), __name__)
				break
			while output_deque:
				temp_vision_frame = output_deque.popleft()
				if not encoder:
					# encoder is started on first frame, once the resolution is known
					height, width = temp_vision_frame.shape[:2]
					encoder = create_vp8_pipe_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
					reader_thread = threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True)
					reader_thread.start()
					logger.info('vp8 encoder started ' + str(width) + 'x' + str(height) + ' [' + backend + ']', __name__)
				feed_whip_frame(encoder, temp_vision_frame)
			# drain encoded VP8 frames produced by the reader thread
			with vp8_lock:
				if vp8_frames:
					pending = list(vp8_frames)
					vp8_frames.clear()
					for frame in pending:
						if backend == 'relay' and udp_sock:
							# 64999 keeps the tagged datagram under the UDP payload limit;
							# larger frames are dropped — TODO confirm this is acceptable
							if len(frame) <= 64999:
								# 0x01 tag marks a video payload for the relay listener
								udp_sock.sendto(b'\x01' + frame, ('127.0.0.1', rtp_port))
						if backend == 'rtc':
							from facefusion import rtc
							rtc.send_vp8_frame(stream_path, frame)
			if not ready_event.is_set() and encoder and encoder.poll() is None:
				# give the encoder a moment before announcing readiness
				time.sleep(1)
				ready_event.set()
			if capture_frame is None and not output_deque:
				# idle: avoid busy-spinning while waiting for new frames
				time.sleep(0.005)
	if encoder:
		encoder.stdin.close()
		encoder.terminate()
		encoder.wait(timeout = 5)
	if udp_sock:
		udp_sock.close()
async def websocket_stream_whip_dc(websocket : WebSocket) -> None:
	# WHIP endpoint (datachannel variant).
	# NOTE(review): body is identical to websocket_stream_whip_py — despite the
	# name it drives the aiortc bridge, not libdatachannel; candidate for
	# consolidation or re-pointing at the dc pipeline.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		# imported lazily so aiortc is only required when this endpoint is used
		from facefusion.aiortc_bridge import AiortcBridge
		bridge = AiortcBridge()
		await bridge.start()
		whep_url = 'http://localhost:' + str(bridge.port) + '/whep'
		latest_frame_holder : list = [None]
		whep_sent = False
		lock = threading.Lock()
		stop_event = threading.Event()
		ready_event = threading.Event()
		worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# announce the playback URL exactly once, after the pipeline is ready
				if not whep_sent and ready_event.is_set():
					await websocket.send_text(whep_url)
					whep_sent = True
				if message.get('bytes'):
					data = message.get('bytes')
					# the JPEG magic prefix distinguishes video frames from audio payloads
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
					if data[:2] != JPEG_MAGIC:
						bridge.push_audio(data)
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		await bridge.stop()
		return
	await websocket.close()
async def websocket_stream_whip_aio(websocket : WebSocket) -> None:
	# WHIP endpoint (aiortc variant).
	# NOTE(review): body is identical to websocket_stream_whip_py and
	# websocket_stream_whip_dc — candidate for consolidation.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		# imported lazily so aiortc is only required when this endpoint is used
		from facefusion.aiortc_bridge import AiortcBridge
		bridge = AiortcBridge()
		await bridge.start()
		whep_url = 'http://localhost:' + str(bridge.port) + '/whep'
		latest_frame_holder : list = [None]
		whep_sent = False
		lock = threading.Lock()
		stop_event = threading.Event()
		ready_event = threading.Event()
		worker = threading.Thread(target = run_aiortc_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, bridge), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# announce the playback URL exactly once, after the pipeline is ready
				if not whep_sent and ready_event.is_set():
					await websocket.send_text(whep_url)
					whep_sent = True
				if message.get('bytes'):
					data = message.get('bytes')
					# the JPEG magic prefix distinguishes video frames from audio payloads
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
					if data[:2] != JPEG_MAGIC:
						bridge.push_audio(data)
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		await bridge.stop()
		return
	await websocket.close()
def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, stop_event : threading.Event, ready_event : threading.Event, stream_path : str) -> None:
from facefusion import rtc
encoder = None
reader_thread = None
vp8_frames : List[bytes] = []
vp8_lock = threading.Lock()
output_deque : Deque[VisionFrame] = deque()
output_deque : deque = deque()
with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
futures = []
@@ -764,8 +136,7 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s
if not encoder:
height, width = temp_vision_frame.shape[:2]
encoder = create_vp8_pipe_encoder(width, height, STREAM_FPS, STREAM_QUALITY)
reader_thread = threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True)
reader_thread.start()
threading.Thread(target = read_ivf_frames, args = (encoder, vp8_frames, vp8_lock), daemon = True).start()
logger.info('vp8 direct encoder started ' + str(width) + 'x' + str(height), __name__)
feed_whip_frame(encoder, temp_vision_frame)
@@ -791,71 +162,6 @@ def run_rtc_direct_pipeline(latest_frame_holder : list, lock : threading.Lock, s
encoder.wait(timeout = 5)
async def websocket_stream_rtc_relay(websocket : WebSocket) -> None:
	# WHIP endpoint backed by libdatachannel: frames flow through the direct rtc
	# pipeline; the client receives the session's WHEP playback URL once ready.
	subprotocol = get_sec_websocket_protocol(websocket.scope)
	access_token = extract_access_token(websocket.scope)
	session_id = session_manager.find_session_id(access_token)
	session_context.set_session_id(session_id)
	source_paths = state_manager.get_item('source_paths')
	logger.info('rtc-relay: session_id=' + str(session_id) + ' source_paths=' + str(bool(source_paths)), __name__)
	await websocket.accept(subprotocol = subprotocol)
	if source_paths:
		from facefusion import rtc
		# without the native library nothing can be streamed; bail out early
		if not rtc.lib:
			logger.error('rtc-relay: libdatachannel not loaded', __name__)
			await websocket.close()
			return
		stream_path = 'stream/' + session_id
		rtc.create_session(stream_path)
		whep_url = 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep'
		latest_frame_holder : list = [None]
		whep_sent = False
		lock = threading.Lock()
		stop_event = threading.Event()
		ready_event = threading.Event()
		worker = threading.Thread(target = run_rtc_direct_pipeline, args = (latest_frame_holder, lock, stop_event, ready_event, stream_path), daemon = True)
		worker.start()
		try:
			while True:
				message = await websocket.receive()
				# announce the playback URL exactly once, after the pipeline is ready
				if not whep_sent and ready_event.is_set():
					await websocket.send_text(whep_url)
					whep_sent = True
				if message.get('bytes'):
					data = message.get('bytes')
					# the JPEG magic prefix distinguishes video frames from audio payloads
					if data[:2] == JPEG_MAGIC:
						frame = cv2.imdecode(numpy.frombuffer(data, numpy.uint8), cv2.IMREAD_COLOR)
						if numpy.any(frame):
							with lock:
								latest_frame_holder[0] = frame
					if data[:2] != JPEG_MAGIC:
						rtc.send_audio(stream_path, data)
		except Exception as exception:
			# receive() raises on client disconnect; treat any failure as end of stream
			logger.error(str(exception), __name__)
		stop_event.set()
		loop = asyncio.get_running_loop()
		await loop.run_in_executor(None, worker.join, 10)
		rtc.destroy_session(stream_path)
		return
	await websocket.close()
async def websocket_stream_rtc(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
access_token = extract_access_token(websocket.scope)
@@ -868,9 +174,10 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None:
if source_paths:
from facefusion import rtc
stream_path = 'stream/' + session_id
rtc.create_session(stream_path)
whep_url = 'http://localhost:' + str(rtc.WHEP_PORT) + '/' + stream_path + '/whep'
whep_url = '/' + stream_path + '/whep'
latest_frame_holder : list = [None]
whep_sent = False
@@ -911,3 +218,18 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None:
return
await websocket.close()
async def post_whep(request : Request) -> Response:
	# WHEP signaling endpoint: exchange a viewer's SDP offer for the SDP answer
	# of the matching stream session (201 with answer, or 404 if no session).
	from facefusion import rtc
	session_id = request.path_params.get('session_id')
	stream_path = 'stream/' + session_id
	body = await request.body()
	sdp_offer = body.decode('utf-8')
	loop = asyncio.get_running_loop()
	# offer handling involves blocking native calls, so keep it off the event loop
	answer = await loop.run_in_executor(None, rtc.handle_whep_offer, stream_path, sdp_offer)
	if answer:
		return Response(answer, status_code = 201, media_type = 'application/sdp')
	return Response(status_code = 404)
+6 -188
View File
@@ -1,21 +1,13 @@
import os
import subprocess
import tempfile
import threading
from typing import List, Optional, Tuple
import cv2
from facefusion import ffmpeg_builder
from facefusion.common_helper import is_windows
from facefusion.streamer import process_vision_frame
from facefusion.types import VisionFrame
STREAM_FPS : int = 30
STREAM_QUALITY : int = 80
STREAM_AUDIO_RATE : int = 48000
DTLS_CERT_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_cert.pem')
DTLS_KEY_FILE : str = os.path.join(tempfile.gettempdir(), 'facefusion_dtls_key.pem')
def compute_bitrate(width : int, height : int) -> str:
@@ -46,186 +38,6 @@ def compute_bufsize(width : int, height : int) -> str:
return '10000k'
def create_dtls_certificate() -> None:
	# Generate a self-signed EC certificate for DTLS (used by the WHIP encoder)
	# in the temp directory; idempotent — skips if the files already exist.
	if os.path.isfile(DTLS_CERT_FILE) and os.path.isfile(DTLS_KEY_FILE):
		return
	# NOTE(review): failure of the openssl call is silent; callers assume the
	# files exist afterwards — consider checking the return code
	subprocess.run([
		'openssl', 'req', '-x509', '-newkey', 'ec', '-pkeyopt', 'ec_paramgen_curve:prime256v1',
		'-keyout', DTLS_KEY_FILE, '-out', DTLS_CERT_FILE,
		'-days', '365', '-nodes', '-subj', '/CN=facefusion'
	], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
def create_whip_encoder(width : int, height : int, stream_fps : int, stream_quality : int, whip_url : str) -> Tuple[subprocess.Popen[bytes], int]:
	# Spawn an ffmpeg process that muxes raw video (stdin) and raw PCM audio
	# (a dedicated pipe) into H264/Opus and publishes it to a WHIP endpoint.
	# Returns the process and the write end of the audio pipe.
	create_dtls_certificate()
	audio_read_fd, audio_write_fd = os.pipe()
	commands = ffmpeg_builder.chain(
		[ '-use_wallclock_as_timestamps', '1' ],
		ffmpeg_builder.capture_video(),
		ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
		ffmpeg_builder.set_input('-'),
		[ '-use_wallclock_as_timestamps', '1' ],
		# second input: raw stereo PCM read from the inherited pipe fd
		[ '-f', 's16le', '-ar', str(STREAM_AUDIO_RATE), '-ac', '2', '-i', 'pipe:' + str(audio_read_fd) ],
		ffmpeg_builder.set_video_encoder('libx264'),
		ffmpeg_builder.set_video_quality('libx264', stream_quality),
		ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
		[ '-pix_fmt', 'yuv420p' ],
		[ '-profile:v', 'baseline' ],
		[ '-tune', 'zerolatency' ],
		[ '-maxrate', compute_bitrate(width, height) ],
		[ '-bufsize', compute_bufsize(width, height) ],
		# keyframe interval of one second keeps startup latency for viewers low
		[ '-g', str(stream_fps) ],
		[ '-c:a', 'libopus' ],
		[ '-f', 'whip' ],
		[ '-cert_file', DTLS_CERT_FILE ],
		[ '-key_file', DTLS_KEY_FILE ],
		ffmpeg_builder.set_output(whip_url)
	)
	commands = ffmpeg_builder.run(commands)
	if is_windows():
		# Windows has no pass_fds; mark the fd inheritable and keep all fds open
		os.set_inheritable(audio_read_fd, True)
		process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE, close_fds = False)
	else:
		process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE, pass_fds = (audio_read_fd,))
	# parent no longer needs the read end once the child holds it
	os.close(audio_read_fd)
	return process, audio_write_fd
def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None:
	# ffmpeg expects raw rgb24 on stdin, so convert from OpenCV's BGR order first.
	raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
	process.stdin.write(raw_bytes)
	process.stdin.flush()
def feed_whip_audio(audio_write_fd : int, audio_data : bytes) -> None:
	# Forward raw PCM into the encoder's dedicated audio pipe.
	os.write(audio_write_fd, audio_data)
def close_whip_encoder(process : subprocess.Popen[bytes], audio_write_fd : int) -> None:
	# Shut down a WHIP encoder: release the audio pipe, close stdin and stop ffmpeg.
	# Guard the fd like close_fmp4_encoder does, so calling with the -1 sentinel
	# (audio pipe never handed out) does not raise OSError.
	if audio_write_fd > 0:
		os.close(audio_write_fd)
	process.stdin.close()
	process.terminate()
	process.wait(timeout = 5)
def create_fmp4_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> Tuple[subprocess.Popen[bytes], int]:
	# Spawn an ffmpeg process that muxes raw video (stdin) and raw PCM audio
	# (a dedicated pipe) into streamable fragmented MP4 on stdout.
	# Returns the process and the write end of the audio pipe.
	audio_read_fd, audio_write_fd = os.pipe()
	commands = ffmpeg_builder.chain(
		[ '-use_wallclock_as_timestamps', '1' ],
		ffmpeg_builder.capture_video(),
		ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
		ffmpeg_builder.set_input('-'),
		[ '-use_wallclock_as_timestamps', '1' ],
		# second input: raw stereo PCM read from the inherited pipe fd
		[ '-f', 's16le', '-ar', str(STREAM_AUDIO_RATE), '-ac', '2', '-i', 'pipe:' + str(audio_read_fd) ],
		[ '-thread_queue_size', '512' ],
		ffmpeg_builder.set_video_encoder('libx264'),
		ffmpeg_builder.set_video_quality('libx264', stream_quality),
		ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
		[ '-pix_fmt', 'yuv420p' ],
		[ '-profile:v', 'baseline' ],
		[ '-tune', 'zerolatency' ],
		[ '-maxrate', compute_bitrate(width, height) ],
		[ '-bufsize', compute_bufsize(width, height) ],
		[ '-g', str(stream_fps) ],
		[ '-c:a', 'aac' ],
		[ '-b:a', '128k' ],
		[ '-f', 'mp4' ],
		# fragmented MP4 flags make the output playable while still being written
		[ '-movflags', 'frag_keyframe+empty_moov+default_base_moof+frag_every_frame' ],
		ffmpeg_builder.set_output('-')
	)
	commands = ffmpeg_builder.run(commands)
	if is_windows():
		# Windows has no pass_fds; mark the fd inheritable and keep all fds open
		os.set_inheritable(audio_read_fd, True)
		process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE, close_fds = False)
	else:
		process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE, pass_fds = (audio_read_fd,))
	# parent no longer needs the read end once the child holds it
	os.close(audio_read_fd)
	return process, audio_write_fd
def read_fmp4_output(process : subprocess.Popen[bytes], output_chunks : List[bytes], lock : threading.Lock) -> None:
	# Reader thread: drain the encoder's stdout into the shared chunk list until EOF.
	stream = process.stdout
	while chunk := stream.read(4096):
		with lock:
			output_chunks.append(chunk)
def collect_fmp4_chunks(output_chunks : List[bytes], lock : threading.Lock) -> Optional[bytes]:
	# Atomically drain the chunk buffer; None signals nothing is pending.
	with lock:
		if not output_chunks:
			return None
		pending = list(output_chunks)
		output_chunks.clear()
	return b''.join(pending)
def close_fmp4_encoder(process : subprocess.Popen[bytes], audio_write_fd : int) -> None:
	# Release the audio pipe when one was handed out, then shut the encoder down.
	if audio_write_fd > 0:
		os.close(audio_write_fd)
	stdin_stream = process.stdin
	stdin_stream.close()
	process.terminate()
	process.wait(timeout = 5)
def create_rtp_encoder(width : int, height : int, stream_fps : int, stream_quality : int, rtp_port : int) -> subprocess.Popen[bytes]:
	# Spawn an ffmpeg process that encodes raw video from stdin to H264 and
	# sends it as RTP to a local port (video only, no audio).
	commands = ffmpeg_builder.chain(
		[ '-use_wallclock_as_timestamps', '1' ],
		ffmpeg_builder.capture_video(),
		ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
		ffmpeg_builder.set_input('-'),
		ffmpeg_builder.set_video_encoder('libx264'),
		ffmpeg_builder.set_video_quality('libx264', stream_quality),
		ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
		[ '-pix_fmt', 'yuv420p' ],
		[ '-profile:v', 'baseline' ],
		[ '-tune', 'zerolatency' ],
		[ '-maxrate', compute_bitrate(width, height) ],
		[ '-bufsize', compute_bufsize(width, height) ],
		[ '-g', str(stream_fps) ],
		[ '-an' ],
		[ '-f', 'rtp' ],
		[ '-payload_type', '96' ],
		# pkt_size 1200 keeps RTP packets below the typical path MTU
		ffmpeg_builder.set_output('rtp://127.0.0.1:' + str(rtp_port) + '?pkt_size=1200')
	)
	commands = ffmpeg_builder.run(commands)
	process = subprocess.Popen(commands, stdin = subprocess.PIPE, stderr = subprocess.PIPE)
	return process
def create_h264_pipe_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
	# Spawn an ffmpeg process that encodes raw video from stdin to an Annex-B
	# H264 byte stream on stdout (video only, no audio).
	commands = ffmpeg_builder.chain(
		[ '-use_wallclock_as_timestamps', '1' ],
		ffmpeg_builder.capture_video(),
		ffmpeg_builder.set_media_resolution(str(width) + 'x' + str(height)),
		ffmpeg_builder.set_input('-'),
		ffmpeg_builder.set_video_encoder('libx264'),
		ffmpeg_builder.set_video_quality('libx264', stream_quality),
		ffmpeg_builder.set_video_preset('libx264', 'ultrafast'),
		[ '-pix_fmt', 'yuv420p' ],
		[ '-profile:v', 'baseline' ],
		[ '-tune', 'zerolatency' ],
		[ '-maxrate', compute_bitrate(width, height) ],
		[ '-bufsize', compute_bufsize(width, height) ],
		# every frame is a keyframe so consumers can start decoding at any point
		[ '-g', '1' ],
		[ '-an' ],
		[ '-f', 'h264' ],
		ffmpeg_builder.set_output('-')
	)
	commands = ffmpeg_builder.run(commands)
	process = subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
	return process
def create_vp8_pipe_encoder(width : int, height : int, stream_fps : int, stream_quality : int) -> subprocess.Popen[bytes]:
commands = ffmpeg_builder.chain(
[ '-use_wallclock_as_timestamps', '1' ],
@@ -255,5 +67,11 @@ def create_vp8_pipe_encoder(width : int, height : int, stream_fps : int, stream_
return process
def feed_whip_frame(process : subprocess.Popen[bytes], vision_frame : VisionFrame) -> None:
	# ffmpeg expects raw rgb24 on stdin, so convert from OpenCV's BGR order first.
	# NOTE(review): this redefines feed_whip_frame already declared earlier in
	# the file with an identical body — the duplicate should be removed.
	raw_bytes = cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB).tobytes()
	process.stdin.write(raw_bytes)
	process.stdin.flush()
def process_stream_frame(vision_frame : VisionFrame) -> VisionFrame:
	# Thin wrapper around the streamer's frame processor, kept as the single
	# entry point used by the pipeline workers.
	return process_vision_frame(vision_frame)
-107
View File
@@ -1,107 +0,0 @@
import os
import shutil
import subprocess
import time
from typing import Optional
import httpx
from facefusion.common_helper import is_linux
MEDIAMTX_WHIP_PORT : int = 8889
MEDIAMTX_API_PORT : int = 9997
MEDIAMTX_CONFIG : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'mediamtx.yml')
MEDIAMTX_FALLBACK_BINARY : str = '/home/henry/local/bin/mediamtx'
MEDIAMTX_PROCESS : Optional[subprocess.Popen[bytes]] = None
def get_whip_url(stream_path : str) -> str:
	# Build the local WHIP ingest URL for the given stream path.
	return f'http://localhost:{MEDIAMTX_WHIP_PORT}/{stream_path}/whip'
def get_whep_url(stream_path : str) -> str:
	# Build the local WHEP playback URL for the given stream path.
	return f'http://localhost:{MEDIAMTX_WHIP_PORT}/{stream_path}/whep'
def get_api_url() -> str:
	# Base URL of the local MediaMTX control API.
	return f'http://localhost:{MEDIAMTX_API_PORT}'
def resolve_binary() -> str:
	# Prefer a mediamtx binary found on PATH; otherwise fall back to the pinned location.
	return shutil.which('mediamtx') or MEDIAMTX_FALLBACK_BINARY
def start() -> None:
	# Launch a fresh mediamtx process after killing any stale instance still
	# holding the ports; output is discarded.
	global MEDIAMTX_PROCESS
	stop_stale()
	mediamtx_binary = resolve_binary()
	MEDIAMTX_PROCESS = subprocess.Popen(
		[ mediamtx_binary, MEDIAMTX_CONFIG ],
		stdout = subprocess.DEVNULL,
		stderr = subprocess.DEVNULL
	)
def stop() -> None:
	# Terminate the managed mediamtx process (if any) and wait for it to exit.
	global MEDIAMTX_PROCESS
	if MEDIAMTX_PROCESS:
		MEDIAMTX_PROCESS.terminate()
		MEDIAMTX_PROCESS.wait()
		MEDIAMTX_PROCESS = None
def stop_stale() -> None:
	# Best-effort kill of processes bound to the mediamtx ports from a previous
	# run; fuser is Linux-only, so other platforms are skipped.
	if is_linux():
		subprocess.run([ 'fuser', '-k', str(MEDIAMTX_WHIP_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
		# 8189/udp is presumably the WebRTC media port — TODO confirm against mediamtx.yml
		subprocess.run([ 'fuser', '-k', '8189/udp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
		subprocess.run([ 'fuser', '-k', str(MEDIAMTX_API_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
		# give the kernel a moment to release the sockets before restarting
		time.sleep(1)
def wait_for_ready() -> bool:
	# Poll the MediaMTX API until it answers (max ~5s: 10 tries, 0.5s apart).
	# Returns True once reachable, False on timeout.
	api_url = get_api_url() + '/v3/paths/list'
	for _ in range(10):
		try:
			response = httpx.get(api_url, timeout = 1)
			if response.status_code == 200:
				return True
		except Exception:
			# connection refused while mediamtx is still starting — keep polling
			pass
		time.sleep(0.5)
	return False
def is_path_ready(stream_path : str) -> bool:
	# Ask the MediaMTX API whether the given path has an active publisher.
	# Any API failure is treated as "not ready".
	api_url = get_api_url() + '/v3/paths/get/' + stream_path
	try:
		response = httpx.get(api_url, timeout = 1)
		if response.status_code == 200:
			return response.json().get('ready', False)
	except Exception:
		pass
	return False
def add_path(stream_path : str) -> bool:
	# Register a new path on MediaMTX via its config API; True on success.
	endpoint = get_api_url() + '/v3/config/paths/add/' + stream_path
	return httpx.post(endpoint, json = {}, timeout = 5).status_code == 200
def remove_path(stream_path : str) -> bool:
	# Remove a path from MediaMTX via its config API; True on success.
	endpoint = get_api_url() + '/v3/config/paths/delete/' + stream_path
	return httpx.delete(endpoint, timeout = 5).status_code == 200
+19 -320
View File
@@ -2,45 +2,24 @@ import ctypes
import ctypes.util
import os
import threading
import time as _time
from http.server import BaseHTTPRequestHandler, HTTPServer
import time
from typing import Dict, List, Optional, TypeAlias
from facefusion import logger
from facefusion.common_helper import is_macos, is_windows
RtcLib : TypeAlias = ctypes.CDLL
WHEP_PORT : int = 8892
lib : Optional[RtcLib] = None
sessions : Dict[str, dict] = {}
http_thread : Optional[threading.Thread] = None
running : bool = False
RTC_NEW = 0
RTC_CONNECTING = 1
RTC_CONNECTED = 2
RTC_DISCONNECTED = 3
RTC_FAILED = 4
RTC_CLOSED = 5
RTC_GATHERING_NEW = 0
RTC_GATHERING_INPROGRESS = 1
RTC_GATHERING_COMPLETE = 2
RTC_DIRECTION_SENDONLY = 0
RTC_DIRECTION_RECVONLY = 1
RTC_DIRECTION_SENDRECV = 2
RTC_DIRECTION_INACTIVE = 3
RTC_DIRECTION_UNKNOWN = 4
LOG_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)
DESCRIPTION_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
CANDIDATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
STATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
GATHERING_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
TRACK_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
MESSAGE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p)
class RtcConfiguration(ctypes.Structure):
@@ -85,11 +64,16 @@ def find_library() -> Optional[str]:
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
bin_dir = os.path.join(project_root, 'bin')
if is_windows():
return os.path.join(bin_dir, 'windows-x64-openssl-h264-vp8-av1-opus-datachannel-0.24.1.dll')
if is_macos():
return os.path.join(bin_dir, 'macos-universal-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.dylib')
return os.path.join(bin_dir, 'linux-x64-openssl-h264-vp8-av1-opus-libdatachannel-0.24.1.so')
if not os.path.isdir(bin_dir):
return None
ext = '.dll' if os.name == 'nt' else '.so'
for name in os.listdir(bin_dir):
if 'datachannel' in name and name.endswith(ext):
return os.path.join(bin_dir, name)
return None
def load_library() -> bool:
@@ -121,9 +105,6 @@ def setup_prototypes() -> None:
lib.rtcDeletePeerConnection.argtypes = [ctypes.c_int]
lib.rtcDeletePeerConnection.restype = ctypes.c_int
lib.rtcSetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p]
lib.rtcSetLocalDescription.restype = ctypes.c_int
lib.rtcSetRemoteDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
lib.rtcSetRemoteDescription.restype = ctypes.c_int
@@ -133,9 +114,6 @@ def setup_prototypes() -> None:
lib.rtcAddTrack.argtypes = [ctypes.c_int, ctypes.c_char_p]
lib.rtcAddTrack.restype = ctypes.c_int
lib.rtcSetUserPointer.argtypes = [ctypes.c_int, ctypes.c_void_p]
lib.rtcSetUserPointer.restype = None
lib.rtcSetLocalDescriptionCallback.argtypes = [ctypes.c_int, DESCRIPTION_CALLBACK_TYPE]
lib.rtcSetLocalDescriptionCallback.restype = ctypes.c_int
@@ -148,18 +126,9 @@ def setup_prototypes() -> None:
lib.rtcSetGatheringStateChangeCallback.argtypes = [ctypes.c_int, GATHERING_CALLBACK_TYPE]
lib.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int
lib.rtcSetTrackCallback.argtypes = [ctypes.c_int, TRACK_CALLBACK_TYPE]
lib.rtcSetTrackCallback.restype = ctypes.c_int
lib.rtcSetMessageCallback.argtypes = [ctypes.c_int, MESSAGE_CALLBACK_TYPE]
lib.rtcSetMessageCallback.restype = ctypes.c_int
lib.rtcSendMessage.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
lib.rtcSendMessage.restype = ctypes.c_int
lib.rtcSetH264Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
lib.rtcSetH264Packetizer.restype = ctypes.c_int
lib.rtcSetVP8Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
lib.rtcSetVP8Packetizer.restype = ctypes.c_int
@@ -204,95 +173,11 @@ def create_peer_connection() -> int:
return lib.rtcCreatePeerConnection(ctypes.byref(config))
next_rtp_port : int = 16000
def create_session(stream_path : str) -> None:
	"""Register an empty viewer session for *stream_path* and reset the video frame counter."""
	global video_frame_count

	video_frame_count = 0
	sessions[stream_path] =\
	{
		'viewers': [],
		'tracks': [],
		'rtp_port': 0,
		'rtp_fd': None
	}
def create_rtp_session(stream_path : str) -> int:
	"""Open a loopback UDP socket for RTP ingest, register the session and start its forwarder thread.

	Returns the allocated local port; the publisher is expected to send tagged
	RTP payloads to 127.0.0.1:<port>.
	"""
	global next_rtp_port
	import socket as sock
	# ports are handed out sequentially from the module-level counter
	# NOTE(review): not thread-safe if called concurrently — confirm callers serialize
	rtp_port = next_rtp_port
	next_rtp_port += 1
	rtp_fd = sock.socket(sock.AF_INET, sock.SOCK_DGRAM)
	rtp_fd.bind(('127.0.0.1', rtp_port))
	# short timeout keeps the forwarder loop responsive to shutdown
	rtp_fd.settimeout(1.0)
	sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': rtp_port, 'rtp_fd': rtp_fd}
	rtp_thread = threading.Thread(target = run_rtp_forwarder, args = (stream_path,), daemon = True)
	rtp_thread.start()
	return rtp_port
def run_rtp_forwarder(stream_path : str) -> None:
	"""Forwarder loop: read tagged datagrams from the session's UDP socket and dispatch them.

	Wire format: first byte is a tag (0x01 = video payload, 0x02 = audio
	payload), the remainder is the payload itself. Runs until the module-level
	running flag clears or the session loses its socket.
	"""
	session = sessions.get(stream_path)
	if not session:
		return
	rtp_fd = session.get('rtp_fd')
	while running and session.get('rtp_fd'):
		try:
			data, addr = rtp_fd.recvfrom(262144)
			if len(data) < 2:
				continue
			tag = data[0]
			payload = data[1:]
			if tag == 0x01:
				send_to_viewers(stream_path, payload)
			if tag == 0x02:
				send_audio_to_viewers(stream_path, payload)
		except Exception:
			# includes the 1s socket timeout; loop re-checks running/session state
			# NOTE(review): also silently swallows send errors — confirm intentional
			continue
def send_audio_to_viewers(stream_path : str, opus_data : bytes) -> None:
	"""Fan one encoded Opus frame out to every connected viewer's audio track."""
	global audio_pts
	session = sessions.get(stream_path)
	if not session:
		return
	viewers = session.get('viewers')
	if not viewers:
		return
	# one ctypes buffer shared across all viewers for this frame
	buf = ctypes.create_string_buffer(opus_data)
	for viewer in viewers:
		if not viewer.get('connected'):
			continue
		audio_track_id = viewer.get('audio_track')
		if not audio_track_id:
			continue
		if not lib.rtcIsOpen(audio_track_id):
			continue
		# wrap the running sample counter to the 32-bit RTP timestamp space
		lib.rtcSetTrackRtpTimestamp(audio_track_id, audio_pts & 0xFFFFFFFF)
		lib.rtcSendMessage(audio_track_id, buf, len(opus_data))
	# advance by one frame's worth of samples per call, regardless of viewer count
	audio_pts += OPUS_FRAME_SAMPLES
sessions[stream_path] = {'viewers': []}
send_start_time : float = 0
video_frame_count : int = 0
audio_pts : int = 0
opus_enc = None
audio_buffer : bytearray = bytearray()
@@ -301,7 +186,7 @@ OPUS_FRAME_SAMPLES : int = 960
def send_to_viewers(stream_path : str, data : bytes) -> None:
global video_frame_count
global send_start_time
session = sessions.get(stream_path)
@@ -313,8 +198,11 @@ def send_to_viewers(stream_path : str, data : bytes) -> None:
if not viewers:
return
timestamp = int(video_frame_count * 3000) & 0xFFFFFFFF
video_frame_count += 1
if send_start_time == 0:
send_start_time = time.monotonic()
elapsed = time.monotonic() - send_start_time
timestamp = int(elapsed * 90000) & 0xFFFFFFFF
buf = ctypes.create_string_buffer(data)
data_len = len(data)
@@ -370,10 +258,6 @@ def encode_opus_frame(pcm_data : bytes) -> Optional[bytes]:
return None
def get_opus_encoder() -> None:
	"""Initialise the module-level Opus encoder.

	NOTE(review): despite the getter-style name this returns nothing; it only
	delegates to init_opus_encoder().
	"""
	init_opus_encoder()
def send_audio(stream_path : str, pcm_data : bytes) -> None:
global audio_pts
@@ -422,112 +306,9 @@ def send_audio(stream_path : str, pcm_data : bytes) -> None:
audio_pts += OPUS_FRAME_SAMPLES
h264_au_buffer : Dict[str, bytes] = {}
def send_vp8_frame(stream_path : str, frame_data : bytes) -> None:
	# Reuses the H264 access-unit path unchanged.
	# NOTE(review): send_h264_frame splits on Annex-B NAL start codes, which raw
	# VP8 bitstreams do not contain — confirm the caller feeds H264-framed data.
	send_h264_frame(stream_path, frame_data)
def find_nal_starts(data : bytes) -> List:
	"""Scan *data* for Annex-B NAL start codes.

	Returns a list of (offset, start_code_length) tuples: length 3 for the
	00 00 01 form, 4 for 00 00 00 01.
	"""
	positions = []
	limit = len(data) - 3
	offset = 0
	while offset < limit:
		if data[offset] or data[offset + 1]:
			offset += 1
			continue
		if data[offset + 2] == 1:
			positions.append((offset, 3))
			offset += 3
		elif offset < len(data) - 4 and data[offset + 2] == 0 and data[offset + 3] == 1:
			positions.append((offset, 4))
			offset += 4
		else:
			offset += 1
	return positions
def send_h264_frame(stream_path : str, frame_data : bytes) -> None:
	"""Accumulate Annex-B H264 bytes and fan complete access units out to viewers.

	Incoming data is appended to a per-stream buffer; an access unit is assumed
	to start at every SPS NAL (type 7). Complete units are re-framed as
	length-prefixed NALUs (4-byte big-endian size, start codes stripped) and
	sent on each viewer's first track; the trailing incomplete unit stays
	buffered for the next call.
	"""
	global send_start_time
	session = sessions.get(stream_path)
	if not session:
		return
	viewers = session.get('viewers')
	if not viewers:
		return
	prev = h264_au_buffer.get(stream_path, b'')
	buf = prev + frame_data
	nal_starts = find_nal_starts(buf)
	if len(nal_starts) < 2:
		h264_au_buffer[stream_path] = buf
		return
	# indices of NALs that open an access unit (SPS, nal_type 7)
	au_boundaries = []
	for idx, (pos, sc_len) in enumerate(nal_starts):
		nal_type = buf[pos + sc_len] & 0x1f
		if nal_type == 7:
			au_boundaries.append(idx)
	if len(au_boundaries) < 2:
		h264_au_buffer[stream_path] = buf
		return
	if send_start_time == 0:
		send_start_time = time.monotonic()
	# FIX: was `_time.monotonic()` — a NameError; the module uses `time` for the
	# same clock in send_to_viewers.
	elapsed = time.monotonic() - send_start_time
	frame_duration = 1.0 / 30.0
	for k in range(len(au_boundaries) - 1):
		start_nal = au_boundaries[k]
		end_nal = au_boundaries[k + 1]
		# NOTE(review): timestamp is computed (90 kHz clock, 30 fps pacing) but
		# never applied to the track — confirm whether rtcSetTrackRtpTimestamp
		# should be called here.
		timestamp = int((elapsed + k * frame_duration) * 90000) & 0xFFFFFFFF
		nalu_parts = []
		for nal_idx in range(start_nal, end_nal):
			nal_pos = nal_starts[nal_idx][0]
			nal_sc_len = nal_starts[nal_idx][1]
			if nal_idx + 1 < len(nal_starts):
				nal_end = nal_starts[nal_idx + 1][0]
			else:
				nal_end = len(buf)
			nalu = buf[nal_pos + nal_sc_len:nal_end]
			if len(nalu) > 0:
				nalu_parts.append(len(nalu).to_bytes(4, 'big') + nalu)
		if nalu_parts:
			frame_msg = b''.join(nalu_parts)
			for viewer in viewers:
				tracks = viewer.get('tracks', [])
				if tracks:
					lib.rtcSendMessage(tracks[0], frame_msg, len(frame_msg))
	# keep everything from the last (still incomplete) access unit buffered
	last_boundary = au_boundaries[-1]
	h264_au_buffer[stream_path] = buf[nal_starts[last_boundary][0]:]
def destroy_session(stream_path : str) -> None:
session = sessions.pop(stream_path, None)
if not session:
return
for viewer in session.get('viewers', []):
pc_id = viewer.get('pc')
@@ -535,26 +316,8 @@ def destroy_session(stream_path : str) -> None:
lib.rtcDeletePeerConnection(pc_id)
def send_data(stream_path : str, data : bytes) -> None:
	"""Send *data* verbatim to every track of every viewer of *stream_path*."""
	session = sessions.get(stream_path)
	if not session:
		return
	for viewer in session.get('viewers', []):
		for track_id in viewer.get('tracks', []):
			lib.rtcSendMessage(track_id, data, len(data))
def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]:
session = sessions.get(stream_path)
if not session:
return None
if not lib:
return None
pc = create_peer_connection()
gathering_done = threading.Event()
local_sdp_holder = [None]
@@ -642,73 +405,9 @@ def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]:
def start() -> None:
	"""Load libdatachannel and launch the WHEP signalling HTTP server (idempotent).

	Does nothing if already running or if the native library cannot be loaded.
	"""
	global running, http_thread
	if running:
		return
	if not load_library():
		return
	running = True
	http_thread = threading.Thread(target = run_http_server, daemon = True)
	http_thread.start()
	logger.info('rtc whep server started on port ' + str(WHEP_PORT), __name__)
load_library()
def stop() -> None:
	"""Stop the server loop and tear down every active session."""
	global running
	running = False
	# snapshot the keys: destroy_session mutates the sessions dict
	for stream_path in list(sessions.keys()):
		destroy_session(stream_path)
def run_http_server() -> None:
	"""Serve WHEP signalling (POST <stream>/whep with an SDP offer) until running clears."""
	class WhepHandler(BaseHTTPRequestHandler):
		def log_message(self, format, *args) -> None:
			# silence default stderr request logging
			pass
		def send_cors_headers(self) -> None:
			self.send_header('Access-Control-Allow-Origin', '*')
			self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
			self.send_header('Access-Control-Allow-Headers', 'Content-Type')
		def do_OPTIONS(self) -> None:
			# CORS preflight
			self.send_response(200)
			self.send_cors_headers()
			self.end_headers()
		def do_POST(self) -> None:
			path = self.path
			if not path.endswith('/whep'):
				self.send_response(404)
				self.send_cors_headers()
				self.end_headers()
				return
			# '/my/stream/whep' -> 'my/stream'
			stream_path = path[1:].rsplit('/whep', 1)[0]
			content_length = int(self.headers.get('Content-Length', 0))
			body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
			answer = handle_whep_offer(stream_path, body)
			if answer:
				# 201 + Location per the WHEP convention
				self.send_response(201)
				self.send_header('Content-Type', 'application/sdp')
				self.send_header('Location', path)
				self.send_cors_headers()
				self.end_headers()
				self.wfile.write(answer.encode('utf-8'))
				return
			self.send_response(404)
			self.send_cors_headers()
			self.end_headers()
	server = HTTPServer(('0.0.0.0', WHEP_PORT), WhepHandler)
	# 1s timeout so the loop re-checks the running flag
	server.timeout = 1
	while running:
		server.handle_request()
-546
View File
@@ -1,546 +0,0 @@
import binascii
import hashlib
import os
import socket
import struct
import threading
import time
from typing import Dict, Optional, Tuple, TypeAlias
import pylibsrtp
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from OpenSSL import SSL
from facefusion import logger
SrtpSession : TypeAlias = pylibsrtp.Session
SrtpPolicy : TypeAlias = pylibsrtp.Policy
WHIP_PORT : int = 8890
ICE_UFRAG_LENGTH : int = 4
ICE_PWD_LENGTH : int = 22
RTP_HEADER_SIZE : int = 12
SRTP_PROFILES =\
[
{
'name': b'SRTP_AES128_CM_SHA1_80',
'libsrtp': SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80,
'key_len': 16,
'salt_len': 14
}
]
sessions : Dict[str, dict] = {}
server_cert = None
server_key = None
server_fingerprint : str = ''
udp_socket : Optional[socket.socket] = None
http_thread : Optional[threading.Thread] = None
udp_thread : Optional[threading.Thread] = None
running : bool = False
def generate_credentials() -> None:
	"""Create the self-signed ECDSA P-256 certificate used for DTLS and cache its SHA-256 fingerprint."""
	global server_cert, server_key, server_fingerprint
	server_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
	# random common name so every run gets a unique self-signed subject
	name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, binascii.hexlify(os.urandom(16)).decode())])
	import datetime
	now = datetime.datetime.now(tz = datetime.timezone.utc)
	# valid from yesterday (clock-skew slack) for 30 days
	builder = x509.CertificateBuilder().subject_name(name).issuer_name(name).public_key(server_key.public_key()).serial_number(x509.random_serial_number()).not_valid_before(now - datetime.timedelta(days = 1)).not_valid_after(now + datetime.timedelta(days = 30))
	server_cert = builder.sign(server_key, hashes.SHA256(), default_backend())
	# colon-separated upper-case hex — the form SDP a=fingerprint expects
	fp = server_cert.fingerprint(hashes.SHA256()).hex().upper()
	server_fingerprint = ':'.join(fp[i:i + 2] for i in range(0, len(fp), 2))
def generate_ice_credentials() -> Tuple[str, str]:
	"""Return a fresh (ufrag, pwd) pair of random lowercase-hex ICE credentials."""
	ufrag = os.urandom(ICE_UFRAG_LENGTH).hex()
	pwd = os.urandom(ICE_PWD_LENGTH).hex()[:ICE_PWD_LENGTH]
	return ufrag, pwd
def parse_sdp_offer(sdp : str) -> dict:
	"""Extract session-level ICE/DTLS attributes, media sections and candidates from an SDP offer."""
	offer =\
	{
		'ice_ufrag': '',
		'ice_pwd': '',
		'fingerprint': '',
		'setup': '',
		'media': [],
		'candidates': []
	}
	media_section = None

	for raw_line in sdp.splitlines():
		sdp_line = raw_line.strip()
		if sdp_line.startswith('a=ice-ufrag:'):
			offer['ice_ufrag'] = sdp_line.split(':', 1)[1]
		elif sdp_line.startswith('a=ice-pwd:'):
			offer['ice_pwd'] = sdp_line.split(':', 1)[1]
		elif sdp_line.startswith('a=fingerprint:'):
			# keep only the hash value, dropping the algorithm token
			offer['fingerprint'] = sdp_line.split(' ', 1)[1] if ' ' in sdp_line else ''
		elif sdp_line.startswith('a=setup:'):
			offer['setup'] = sdp_line.split(':', 1)[1]
		elif sdp_line.startswith('a=candidate:'):
			offer['candidates'].append(sdp_line[12:])
		elif sdp_line.startswith('m='):
			fields = sdp_line[2:].split()
			media_section =\
			{
				'kind': fields[0],
				'port': int(fields[1]),
				'profile': fields[2],
				'formats': fields[3:],
				'codec_lines': [],
				'mid': None
			}
			offer['media'].append(media_section)
		if media_section:
			# codec and extension lines are echoed verbatim into answers
			if sdp_line.startswith(('a=rtpmap:', 'a=fmtp:', 'a=rtcp-fb:', 'a=extmap:')):
				media_section['codec_lines'].append(sdp_line)
			elif sdp_line.startswith('a=mid:'):
				media_section['mid'] = sdp_line.split(':', 1)[1]
	return offer
def build_sdp_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int) -> str:
	"""Build an ICE-lite, recvonly SDP answer for a WHIP ingest offer.

	Mirrors the offered media sections and codec lines, advertising this
	server's ICE credentials, DTLS fingerprint and a single loopback host
	candidate on *local_port*.
	"""
	lines = []
	lines.append('v=0')
	lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
	lines.append('s=-')
	lines.append('t=0 0')
	mids = []
	for i, media in enumerate(offer.get('media', [])):
		# FIX: BUNDLE ids must match the a=mid values emitted below (RFC 8843);
		# previously the group always used the section index.
		mids.append(media.get('mid') or str(i))
	if mids:
		lines.append('a=group:BUNDLE ' + ' '.join(mids))
	lines.append('a=ice-lite')
	for i, media in enumerate(offer.get('media', [])):
		kind = media.get('kind')
		formats = media.get('formats', [])
		profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
		# FIX: parse_sdp_offer stores 'mid': None when the offer has no a=mid,
		# so dict.get's default never applied and 'a=mid:' + None raised;
		# `or` covers both missing and None.
		mid = media.get('mid') or str(i)
		lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
		lines.append('c=IN IP4 127.0.0.1')
		lines.append('a=rtcp:9 IN IP4 0.0.0.0')
		lines.append('a=ice-ufrag:' + local_ufrag)
		lines.append('a=ice-pwd:' + local_pwd)
		lines.append('a=ice-options:ice2')
		lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
		# passive: the browser initiates the DTLS handshake
		lines.append('a=setup:passive')
		lines.append('a=mid:' + mid)
		lines.append('a=rtcp-mux')
		lines.append('a=recvonly')
		for codec_line in media.get('codec_lines', []):
			lines.append(codec_line)
		lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
		lines.append('a=end-of-candidates')
	return '\r\n'.join(lines) + '\r\n'
def build_whep_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int, ingest_offer : dict) -> str:
	"""Build an ICE-lite, sendonly SDP answer for a WHEP playback offer.

	Mirrors the viewer's media sections and codec lines with per-viewer ICE
	credentials. NOTE(review): *ingest_offer* is currently unused — confirm
	whether codec lines were meant to come from the ingest side.
	"""
	lines = []
	lines.append('v=0')
	lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
	lines.append('s=-')
	lines.append('t=0 0')
	mids = []
	for i, media in enumerate(offer.get('media', [])):
		# FIX: parse_sdp_offer stores 'mid': None when absent (key present), so
		# dict.get's default never applied and ' '.join(mids) raised on None;
		# `or` covers both missing and None.
		mid = media.get('mid') or str(i)
		mids.append(mid)
	if mids:
		lines.append('a=group:BUNDLE ' + ' '.join(mids))
	lines.append('a=ice-lite')
	for i, media in enumerate(offer.get('media', [])):
		kind = media.get('kind')
		formats = media.get('formats', [])
		profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
		# same FIX as above for the per-section mid
		mid = media.get('mid') or str(i)
		lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
		lines.append('c=IN IP4 127.0.0.1')
		lines.append('a=rtcp:9 IN IP4 0.0.0.0')
		lines.append('a=ice-ufrag:' + local_ufrag)
		lines.append('a=ice-pwd:' + local_pwd)
		lines.append('a=ice-options:ice2')
		lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
		# passive: the browser initiates the DTLS handshake
		lines.append('a=setup:passive')
		lines.append('a=mid:' + mid)
		lines.append('a=rtcp-mux')
		lines.append('a=sendonly')
		for codec_line in media.get('codec_lines', []):
			lines.append(codec_line)
		lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
		lines.append('a=end-of-candidates')
	return '\r\n'.join(lines) + '\r\n'
def create_ssl_context() -> SSL.Context:
	"""Build the DTLS server context used for incoming WebRTC handshakes."""
	ctx = SSL.Context(SSL.DTLS_METHOD)
	# request (and require) a client certificate but accept any — WebRTC peers
	# are authenticated via the SDP fingerprint, not a CA chain
	ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True)
	ctx.use_certificate(server_cert)
	ctx.use_privatekey(server_key)
	ctx.set_cipher_list(b'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-SHA')
	# negotiate the single SRTP profile this module derives keys for
	ctx.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80')
	return ctx
def create_session(stream_path : str) -> None:
	"""Allocate server-side ICE credentials and an empty session record for *stream_path*."""
	local_ufrag, local_pwd = generate_ice_credentials()
	session = {}
	session['ice_ufrag'] = local_ufrag
	session['ice_pwd'] = local_pwd
	session['ingest'] = None
	session['viewers'] = []
	session['ingest_offer'] = None
	session['rx_srtp'] = None
	session['tx_sessions'] = []
	sessions[stream_path] = session
def destroy_session(stream_path : str) -> None:
	"""Drop the session for *stream_path*; a missing session is ignored."""
	try:
		del sessions[stream_path]
	except KeyError:
		pass
def handle_whip(stream_path : str, sdp_offer : str) -> Optional[str]:
	"""Answer a WHIP publish offer for an existing session; None when the session is unknown."""
	session = sessions.get(stream_path)
	if not session:
		return None
	offer = parse_sdp_offer(sdp_offer)
	# keep the ingest offer so later WHEP viewers can mirror its setup
	session['ingest_offer'] = offer
	local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
	answer = build_sdp_answer(offer, session.get('ice_ufrag'), session.get('ice_pwd'), local_port)
	return answer
def handle_whep(stream_path : str, sdp_offer : str) -> Optional[str]:
	"""Answer a WHEP playback offer with fresh per-viewer ICE credentials; None when unknown."""
	session = sessions.get(stream_path)
	if not session:
		return None
	offer = parse_sdp_offer(sdp_offer)
	viewer_ufrag, viewer_pwd = generate_ice_credentials()
	local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
	# fall back to the viewer's own offer when no publisher has connected yet
	ingest_offer = session.get('ingest_offer', offer)
	answer = build_whep_answer(offer, viewer_ufrag, viewer_pwd, local_port, ingest_offer)
	session['viewers'].append({'offer': offer, 'ice_ufrag': viewer_ufrag, 'ice_pwd': viewer_pwd})
	return answer
def start() -> None:
	"""Generate DTLS credentials, then launch the UDP media loop and the HTTP signalling server."""
	global running, udp_socket, http_thread
	generate_credentials()
	running = True
	udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
	udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
	udp_socket.bind(('0.0.0.0', WHIP_PORT))
	# timeout keeps run_udp_loop responsive to the running flag
	udp_socket.settimeout(1.0)
	# NOTE(review): never assigned to the module-level udp_thread global
	udp_thread_instance = threading.Thread(target = run_udp_loop, daemon = True)
	udp_thread_instance.start()
	http_thread = threading.Thread(target = run_http_server, daemon = True)
	http_thread.start()
	logger.info('webrtc sfu started on port ' + str(WHIP_PORT), __name__)
def stop() -> None:
	"""Signal the loops to exit and close the media socket."""
	global running, udp_socket
	running = False
	# closing the socket also unblocks a pending recvfrom in run_udp_loop
	if udp_socket:
		udp_socket.close()
		udp_socket = None
dtls_connections : Dict[tuple, dict] = {}
def run_udp_loop() -> None:
	"""Receive loop on the shared media socket, demultiplexing by first byte.

	RFC 7983-style demux: 0-1 → STUN, 20-63 → DTLS, 128-191 → SRTP/SRTCP.
	Runs until the module-level running flag clears.
	"""
	while running:
		try:
			data, addr = udp_socket.recvfrom(2048)
			if not data:
				continue
			first_byte = data[0]
			if first_byte == 0 or first_byte == 1:
				handle_stun(data, addr)
			if first_byte > 19 and first_byte < 64:
				handle_dtls(data, addr)
			if first_byte > 127 and first_byte < 192:
				handle_srtp(data, addr)
		except socket.timeout:
			# periodic wakeup to re-check running
			continue
		except Exception:
			# during shutdown the socket is closed from stop(); exit quietly then
			if running:
				continue
def handle_dtls(data : bytes, addr : tuple) -> None:
	"""Feed an incoming DTLS record into the per-address connection state machine.

	Creates a server-side connection on first contact, drives the handshake via
	a memory BIO, and once established derives SRTP keys; post-handshake
	records are drained with recv().
	"""
	conn = dtls_connections.get(addr)
	if not conn:
		ctx = create_ssl_context()
		ssl_conn = SSL.Connection(ctx)
		# we are the DTLS server (a=setup:passive in the SDP answers)
		ssl_conn.set_accept_state()
		conn = {'ssl': ssl_conn, 'encrypted': False, 'rx_srtp': None, 'tx_srtp': None}
		dtls_connections[addr] = conn
	ssl_conn = conn.get('ssl')
	# push the datagram into the incoming memory BIO
	ssl_conn.bio_write(data)
	try:
		if not conn.get('encrypted'):
			try:
				ssl_conn.do_handshake()
				conn['encrypted'] = True
				setup_srtp_session(conn)
				logger.info('dtls handshake complete with ' + str(addr), __name__)
			except SSL.WantReadError:
				# handshake needs more flights; reply is flushed below
				pass
		else:
			try:
				# drain application-level DTLS records (e.g. close_notify)
				ssl_conn.recv(1500)
			except SSL.ZeroReturnError:
				pass
			except SSL.Error:
				pass
	except Exception:
		# NOTE(review): swallows handshake failures silently — confirm intentional
		pass
	# send whatever the handshake/record layer queued for the peer
	flush_dtls(ssl_conn, addr)
def flush_dtls(ssl_conn : SSL.Connection, addr : tuple) -> None:
	"""Drain all pending DTLS records from the connection's outgoing BIO to the peer.

	FIX: a handshake flight can consist of several records; reading a single
	1500-byte chunk left the remainder queued in the BIO and could stall the
	handshake. Loop until the BIO is empty (bio_read raises a WantReadError
	subclass of SSL.Error when drained).
	"""
	while True:
		try:
			outdata = ssl_conn.bio_read(1500)
		except SSL.Error:
			return
		if not outdata:
			return
		udp_socket.sendto(outdata, addr)
def setup_srtp_session(conn : dict) -> None:
	"""Derive SRTP keys from the finished DTLS handshake and create rx/tx sessions.

	Exported keying material layout (DTLS-SRTP, key_len=16 / salt_len=14 for
	AES128_CM_SHA1_80): client_key | server_key | client_salt | server_salt.
	Client material decrypts inbound packets; server material protects outbound.
	"""
	ssl_conn = conn.get('ssl')
	# NOTE(review): negotiated profile is fetched but never checked against the
	# hard-coded key/salt lengths below — confirm only one profile is offered
	ssl_conn.get_selected_srtp_profile()
	key_len = 16
	salt_len = 14
	view = ssl_conn.export_keying_material(b'EXTRACTOR-dtls_srtp', 2 * (key_len + salt_len))
	# server write key + server salt
	server_key = view[key_len:2 * key_len] + view[2 * key_len + salt_len:]
	# client write key + client salt
	client_key = view[:key_len] + view[2 * key_len:2 * key_len + salt_len]
	rx_policy = SrtpPolicy(key = client_key, ssrc_type = SrtpPolicy.SSRC_ANY_INBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
	rx_policy.allow_repeat_tx = True
	rx_policy.window_size = 1024
	conn['rx_srtp'] = SrtpSession(rx_policy)
	tx_policy = SrtpPolicy(key = server_key, ssrc_type = SrtpPolicy.SSRC_ANY_OUTBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
	tx_policy.allow_repeat_tx = True
	tx_policy.window_size = 1024
	conn['tx_srtp'] = SrtpSession(tx_policy)
def is_rtcp(data : bytes) -> bool:
	"""Return True when *data* looks like an RTCP packet.

	RTCP payload types occupy 64-95 once the marker bit of the second header
	byte is masked off; anything shorter than two bytes cannot be RTCP.
	"""
	return len(data) >= 2 and 64 <= (data[1] & 0x7F) <= 95
def handle_srtp(data : bytes, addr : tuple) -> None:
	"""Decrypt an inbound SRTP/SRTCP packet and forward the plaintext to all other peers."""
	conn = dtls_connections.get(addr)
	# ignore media from peers whose DTLS handshake has not completed
	if not conn or not conn.get('rx_srtp'):
		return
	try:
		if is_rtcp(data):
			plain = conn.get('rx_srtp').unprotect_rtcp(data)
		else:
			plain = conn.get('rx_srtp').unprotect(data)
		forward_rtp(plain, addr)
	except Exception:
		# NOTE(review): auth/replay failures are dropped silently — confirm intentional
		pass
def forward_rtp(data : bytes, source_addr : tuple) -> None:
	"""Re-encrypt plaintext RTP/RTCP for every other established peer and send it (SFU fan-out)."""
	for other_addr, conn in dtls_connections.items():
		# never echo media back to its sender
		if other_addr == source_addr:
			continue
		# skip peers whose DTLS handshake has not produced tx keys yet
		if not conn.get('tx_srtp'):
			continue
		try:
			if is_rtcp(data):
				encrypted = conn.get('tx_srtp').protect_rtcp(data)
			else:
				encrypted = conn.get('tx_srtp').protect(data)
			udp_socket.sendto(encrypted, other_addr)
		except Exception:
			# a failing peer must not break fan-out to the others
			pass
def handle_stun(data : bytes, addr : tuple) -> None:
	"""Answer a STUN Binding Request whose USERNAME matches a known ICE ufrag.

	Parses the attribute list for USERNAME (0x0006), looks the local ufrag up
	across sessions and their viewers to find the matching ICE password, and
	replies with a signed Binding success response. Non-requests and unknown
	ufrags are ignored.
	"""
	if len(data) < 20:
		return
	msg_type = struct.unpack('!H', data[0:2])[0]
	# only Binding Requests (0x0001) are answered
	if msg_type != 0x0001:
		return
	msg_length = struct.unpack('!H', data[2:4])[0]
	transaction_id = data[8:20]
	username = None
	offset = 20
	while offset < 20 + msg_length:
		if offset + 4 > len(data):
			break
		attr_type = struct.unpack('!H', data[offset:offset + 2])[0]
		attr_length = struct.unpack('!H', data[offset + 2:offset + 4])[0]
		attr_value = data[offset + 4:offset + 4 + attr_length]
		if attr_type == 0x0006:
			username = attr_value.decode('utf-8', errors = 'ignore')
		# attribute values are padded to a 4-byte boundary
		padded = attr_length + (4 - attr_length % 4) % 4
		offset += 4 + padded
	if not username:
		return
	# ICE USERNAME is 'local-ufrag:remote-ufrag'; we key on the local half
	local_ufrag = username.split(':')[0] if ':' in username else username
	session_pwd = None
	for session in sessions.values():
		if session.get('ice_ufrag') == local_ufrag:
			session_pwd = session.get('ice_pwd')
			break
		for viewer in session.get('viewers', []):
			if viewer.get('ice_ufrag') == local_ufrag:
				session_pwd = viewer.get('ice_pwd')
				break
		if session_pwd:
			break
	if not session_pwd:
		return
	response = build_stun_response(transaction_id, addr, session_pwd)
	udp_socket.sendto(response, addr)
def build_stun_response(transaction_id : bytes, addr : tuple, password : str) -> bytes:
import hmac
import zlib
magic_cookie = 0x2112A442
magic_bytes = struct.pack('!I', magic_cookie)
xport = addr[1] ^ (magic_cookie >> 16)
ip_int = struct.unpack('!I', socket.inet_aton(addr[0]))[0]
xip = struct.pack('!I', ip_int ^ magic_cookie)
xor_addr_value = struct.pack('!BBH', 0, 0x01, xport) + xip
xor_addr_attr = struct.pack('!HH', 0x0020, len(xor_addr_value)) + xor_addr_value
attrs_before_integrity = xor_addr_attr
integrity_dummy_len = len(attrs_before_integrity) + 4 + 20
header_for_hmac = struct.pack('!HH', 0x0101, integrity_dummy_len) + magic_bytes + transaction_id
key = password.encode('utf-8')
integrity = hmac.new(key, header_for_hmac + attrs_before_integrity, hashlib.sha1).digest()
integrity_attr = struct.pack('!HH', 0x0008, 20) + integrity
attrs_before_fp = attrs_before_integrity + integrity_attr
fp_dummy_len = len(attrs_before_fp) + 4 + 4
header_for_fp = struct.pack('!HH', 0x0101, fp_dummy_len) + magic_bytes + transaction_id
crc = zlib.crc32(header_for_fp + attrs_before_fp) ^ 0x5354554E
fingerprint_attr = struct.pack('!HHI', 0x8028, 4, crc & 0xFFFFFFFF)
all_attrs = attrs_before_integrity + integrity_attr + fingerprint_attr
header = struct.pack('!HH', 0x0101, len(all_attrs)) + magic_bytes + transaction_id
return header + all_attrs
def run_http_server() -> None:
	"""Serve WHIP (publish) and WHEP (play) signalling endpoints until running clears."""
	from http.server import HTTPServer, BaseHTTPRequestHandler
	class WhipWhepHandler(BaseHTTPRequestHandler):
		def log_message(self, format, *args) -> None:
			# silence default stderr request logging
			pass
		def do_POST(self) -> None:
			path = self.path
			content_length = int(self.headers.get('Content-Length', 0))
			body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
			if path.endswith('/whip'):
				# '/my/stream/whip' -> 'my/stream'
				stream_path = path[1:].rsplit('/whip', 1)[0]
				answer = handle_whip(stream_path, body)
				if answer:
					# 201 + Location per the WHIP convention
					self.send_response(201)
					self.send_header('Content-Type', 'application/sdp')
					self.send_header('Location', path)
					self.send_header('Access-Control-Allow-Origin', '*')
					self.end_headers()
					self.wfile.write(answer.encode('utf-8'))
					return
				self.send_response(404)
				self.end_headers()
				return
			if path.endswith('/whep'):
				stream_path = path[1:].rsplit('/whep', 1)[0]
				answer = handle_whep(stream_path, body)
				if answer:
					self.send_response(201)
					self.send_header('Content-Type', 'application/sdp')
					self.send_header('Location', path)
					self.send_header('Access-Control-Allow-Origin', '*')
					self.end_headers()
					self.wfile.write(answer.encode('utf-8'))
					return
				self.send_response(404)
				self.end_headers()
				return
			self.send_response(404)
			self.end_headers()
		def do_OPTIONS(self) -> None:
			# CORS preflight
			self.send_response(200)
			self.send_header('Access-Control-Allow-Origin', '*')
			self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
			self.send_header('Access-Control-Allow-Headers', 'Content-Type')
			self.end_headers()
	server = HTTPServer(('0.0.0.0', WHIP_PORT), WhipWhepHandler)
	# 1s timeout so the loop re-checks the running flag
	server.timeout = 1
	while running:
		server.handle_request()
-62
View File
@@ -1,62 +0,0 @@
import threading
from typing import Optional
from facefusion import logger
RELAY_PORT : int = 8891
_started : bool = False
_lock : threading.Lock = threading.Lock()
def get_whip_url(stream_path : str) -> str:
	"""Return the local WHIP publish URL for *stream_path*.

	NOTE(review): built from rtc.WHEP_PORT — the rtc module appears to serve
	both endpoints on that port; confirm.
	"""
	from facefusion import rtc
	return f'http://localhost:{rtc.WHEP_PORT}/{stream_path}/whip'
def get_whep_url(stream_path : str) -> str:
	"""Return the local WHEP playback URL for *stream_path*."""
	from facefusion import rtc
	return f'http://localhost:{rtc.WHEP_PORT}/{stream_path}/whep'
def start() -> None:
	"""Ensure the rtc module's native library and WHEP server are up, then mark the relay ready."""
	global _started
	from facefusion import rtc
	if not rtc.lib:
		if not rtc.load_library():
			logger.warn('whip relay: libdatachannel not available', __name__)
			return
	if not rtc.running:
		rtc.start()
	_started = True
	logger.info('whip relay (python) ready on port ' + str(rtc.WHEP_PORT), __name__)
def stop() -> None:
	"""Mark the relay as stopped (does not shut down the underlying rtc module)."""
	global _started
	_started = False
def wait_for_ready() -> bool:
	"""Report readiness; despite the name this does not block."""
	return _started
def is_session_ready(stream_path : str) -> bool:
	"""True when the rtc module already tracks a session for *stream_path*."""
	from facefusion import rtc
	return stream_path in rtc.sessions
def create_session(stream_path : str) -> int:
	"""Create an RTP ingest session via the rtc module and return its UDP port (0 on failure)."""
	from facefusion import rtc
	# lazily bring the relay up on first use
	if not _started:
		start()
	if not rtc.lib:
		return 0
	rtp_port = rtc.create_rtp_session(stream_path)
	return rtp_port
-9
View File
@@ -1,9 +0,0 @@
rtsp: no
rtmp: no
hls: no
srt: no
webrtc: yes
webrtcAddress: :8889
api: yes
apiAddress: :9997
paths:
+18 -256
View File
@@ -139,24 +139,7 @@
</div>
<div class="section">
<div class="section-title"><span class="step-dot" id="dotMode">4</span> Streaming Mode</div>
<div class="form-row">
<label>Approach
<select id="streamMode">
<option value="whip-mediamtx">FFmpeg WHIP + MediaMTX</option>
<option value="whip-python">aiortc WebRTC (no ext. deps)</option>
<option value="whip-datachannel">FFmpeg WHIP + libdatachannel relay</option>
<option value="ws-fmp4">FFmpeg fMP4 + WebSocket (MSE)</option>
<option value="datachannel-direct">libdatachannel direct</option>
<option value="datachannel-relay-py">libdatachannel Python relay (UDP)</option>
<option value="ws-mjpeg">MJPEG over WebSocket (no deps)</option>
</select>
</label>
</div>
</div>
<div class="section">
<div class="section-title"><span class="step-dot" id="dotOptions">5</span> Options</div>
<div class="section-title"><span class="step-dot" id="dotOptions">4</span> Options</div>
<div class="form-row">
<label>Capture Resolution
<select id="captureRes">
@@ -180,7 +163,7 @@
</div>
<div class="section">
<div class="section-title"><span class="step-dot" id="dotStream">6</span> Stream</div>
<div class="section-title"><span class="step-dot" id="dotStream">5</span> Stream</div>
<div class="switch-row">
<span>Ready</span>
<span id="streamReadyHint" style="font-size:0.7rem;color:#555;">set source + video first</span>
@@ -303,38 +286,12 @@ var prevBytes = 0;
var prevFrames = 0;
var prevStatsTime = 0;
var prevFramesSent = 0;
var metricsWs = null;
var mediaSource = null;
var sourceBuffer = null;
var mseQueue = [];
var mseReady = false;
var captureCanvas = document.createElement('canvas');
var captureCtx = captureCanvas.getContext('2d');
var audioCtx = null;
var audioWorklet = null;
var audioEchoWs = null;
var audioPlayCtx = null;
var audioPlayNextTime = 0;
var MODE_CONFIG = {
'whip-mediamtx': { wsPath: '/stream/whip', playback: 'whep' },
'whip-python': { wsPath: '/stream/whip-py', playback: 'whep' },
'whip-datachannel': { wsPath: '/stream/whip-dc', playback: 'whep' },
'ws-fmp4': { wsPath: '/stream/live', playback: 'mse' },
'datachannel-direct': { wsPath: '/stream/rtc', playback: 'whep' },
'datachannel-relay-py': { wsPath: '/stream/rtc-relay', playback: 'whep' },
'ws-mjpeg': { wsPath: '/stream/mjpeg', playback: 'mjpeg' }
};
function getMode() {
return document.getElementById('streamMode').value;
}
function getModeConfig() {
return MODE_CONFIG[getMode()];
}
function log(msg, type) {
type = type || 'info';
@@ -358,10 +315,6 @@ function wsBase() {
return base().replace(/^http/, 'ws');
}
function whepUrl() {
return whepUrlFromServer;
}
function authHeaders() {
return { 'Authorization': 'Bearer ' + accessToken };
}
@@ -753,12 +706,6 @@ function onSeekCommit() {
timelineVideo.currentTime = t;
log('seek → ' + formatTime(t), 'info');
}
if (audioPlayCtx) {
audioPlayCtx.close();
audioPlayCtx = new AudioContext({ sampleRate: 48000 });
audioPlayNextTime = 0;
}
}
function formatTime(s) {
@@ -797,19 +744,11 @@ function captureAndSend() {
}, 'image/jpeg', 0.7);
}
async function connectWhep() {
var url = whepUrl();
async function connectWhep(url) {
var t0 = performance.now();
log('WHEP → ' + url, 'info');
var PeerConnection = window.RTCPeerConnection || window.webkitRTCPeerConnection || window.mozRTCPeerConnection;
if (!PeerConnection) {
log('WebRTC not supported in this browser', 'error');
return;
}
pc = new PeerConnection({ iceServers: [] });
pc = new RTCPeerConnection({ iceServers: [] });
pc.onconnectionstatechange = function() {
var state = pc.connectionState;
@@ -899,156 +838,19 @@ function stopAudioCapture() {
}
}
function startAudioEcho() {
var stream = localStream;
if (!stream || stream.getAudioTracks().length === 0) {
log('no audio track for echo', 'warn');
return;
}
audioPlayCtx = new AudioContext({ sampleRate: 48000 });
audioPlayNextTime = 0;
var captureCtxAudio = new AudioContext({ sampleRate: 48000 });
var source = captureCtxAudio.createMediaStreamSource(stream);
var processor = captureCtxAudio.createScriptProcessor(4096, 2, 2);
var echoUrl = wsBase() + '/stream/audio';
var protocols = ['access_token.' + accessToken];
audioEchoWs = new WebSocket(echoUrl, protocols);
audioEchoWs.binaryType = 'arraybuffer';
audioEchoWs.onmessage = function(event) {
if (!audioPlayCtx) return;
var pcm = new Int16Array(event.data);
var samples = pcm.length / 2;
var buffer = audioPlayCtx.createBuffer(2, samples, 48000);
var left = buffer.getChannelData(0);
var right = buffer.getChannelData(1);
for (var i = 0; i < samples; i++) {
left[i] = pcm[i * 2] / 32768;
right[i] = pcm[i * 2 + 1] / 32768;
}
var bufferSource = audioPlayCtx.createBufferSource();
bufferSource.buffer = buffer;
bufferSource.connect(audioPlayCtx.destination);
var now = audioPlayCtx.currentTime;
if (audioPlayNextTime < now) audioPlayNextTime = now + 0.05;
bufferSource.start(audioPlayNextTime);
audioPlayNextTime += buffer.duration;
};
processor.onaudioprocess = function(e) {
if (!audioEchoWs || audioEchoWs.readyState !== WebSocket.OPEN) return;
var left = e.inputBuffer.getChannelData(0);
var right = e.inputBuffer.getChannelData(1);
var pcm = new Int16Array(left.length * 2);
for (var i = 0; i < left.length; i++) {
pcm[i * 2] = Math.max(-32768, Math.min(32767, left[i] * 32768));
pcm[i * 2 + 1] = Math.max(-32768, Math.min(32767, right[i] * 32768));
}
audioEchoWs.send(pcm.buffer);
};
source.connect(processor);
processor.connect(captureCtxAudio.destination);
log('audio echo started (48kHz stereo s16le)', 'ok');
}
function stopAudioEcho() {
if (audioEchoWs) {
audioEchoWs.close();
audioEchoWs = null;
}
if (audioPlayCtx) {
audioPlayCtx.close();
audioPlayCtx = null;
}
audioPlayNextTime = 0;
}
function initMse() {
var video = document.getElementById('outputVideo');
mediaSource = new MediaSource();
video.src = URL.createObjectURL(mediaSource);
mediaSource.addEventListener('sourceopen', function() {
sourceBuffer = mediaSource.addSourceBuffer('video/mp4; codecs="avc1.42E01E,mp4a.40.2"');
sourceBuffer.mode = 'sequence';
mseReady = true;
sourceBuffer.addEventListener('updateend', function() {
if (mseQueue.length > 0 && !sourceBuffer.updating) {
sourceBuffer.appendBuffer(mseQueue.shift());
}
});
log('MSE source buffer ready', 'ok');
});
}
function feedMse(data) {
if (!mseReady || !sourceBuffer) return;
if (sourceBuffer.updating || mseQueue.length > 0) {
mseQueue.push(data);
} else {
sourceBuffer.appendBuffer(data);
}
}
function cleanupMse() {
mseQueue = [];
mseReady = false;
sourceBuffer = null;
if (mediaSource && mediaSource.readyState === 'open') {
mediaSource.endOfStream();
}
mediaSource = null;
}
async function connect() {
var config = getModeConfig();
var mode = getMode();
framesSent = 0;
whepUrlFromServer = null;
var outputVideo = document.getElementById('outputVideo');
outputVideo.srcObject = null;
outputVideo.removeAttribute('src');
outputVideo.load();
outputVideo.style.display = '';
var mjpegImg = outputVideo._mjpegImg;
if (mjpegImg) {
if (mjpegImg._prevUrl) URL.revokeObjectURL(mjpegImg._prevUrl);
mjpegImg.remove();
outputVideo._mjpegImg = null;
}
cleanupMse();
whepUrlFromServer = null;
var wsUrl = wsBase() + config.wsPath;
var wsUrl = wsBase() + '/stream/rtc';
var protocols = ['access_token.' + accessToken];
var t0 = performance.now();
log('[' + mode + '] ws → ' + wsUrl, 'info');
if (config.playback === 'mse') {
initMse();
}
log('ws → ' + wsUrl, 'info');
ws = new WebSocket(wsUrl, protocols);
ws.binaryType = 'arraybuffer';
@@ -1056,7 +858,6 @@ async function connect() {
ws.onopen = function() {
log('websocket open (' + Math.round(performance.now() - t0) + 'ms) — sending frames', 'ok');
markDone('dotStream');
markDone('dotMode');
document.getElementById('btnPlay').disabled = true;
document.getElementById('btnPlay').classList.add('active');
document.getElementById('btnStop').disabled = false;
@@ -1072,67 +873,31 @@ async function connect() {
updateTrackVisual(0, timelineVideo ? timelineVideo.duration : 0);
captureTimer = setInterval(captureAndSend, 1000 / 30);
if (config.playback === 'mjpeg') {
startAudioEcho();
} else {
startAudioCapture();
}
startAudioCapture();
startStats();
};
var streamStarted = false;
function onFirstOutput() {
if (streamStarted) return;
streamStarted = true;
if (timelineVideo) timelineVideo.play();
startTimelineSync();
log('stream output started', 'ok');
}
ws.onmessage = function(event) {
if (config.playback === 'mse' && event.data instanceof ArrayBuffer) {
onFirstOutput();
feedMse(event.data);
return;
}
if (config.playback === 'mjpeg' && event.data instanceof ArrayBuffer) {
onFirstOutput();
var blob = new Blob([event.data], { type: 'image/jpeg' });
var url = URL.createObjectURL(blob);
var video = document.getElementById('outputVideo');
if (!video._mjpegImg) {
video.style.display = 'none';
var img = document.createElement('img');
img.id = 'mjpegOutput';
img.style.cssText = 'width:100%;height:100%;object-fit:contain;border-radius:8px;';
video.parentNode.appendChild(img);
video._mjpegImg = img;
}
if (video._mjpegImg._prevUrl) URL.revokeObjectURL(video._mjpegImg._prevUrl);
video._mjpegImg.src = url;
video._mjpegImg._prevUrl = url;
return;
}
if (typeof event.data === 'string' && !whepUrlFromServer) {
whepUrlFromServer = event.data;
onFirstOutput();
whepUrlFromServer = base() + event.data;
if (!streamStarted) {
streamStarted = true;
if (timelineVideo) timelineVideo.play();
startTimelineSync();
log('stream output started', 'ok');
}
log('stream ready (' + Math.round(performance.now() - t0) + 'ms) — WHEP url: ' + whepUrlFromServer, 'ok');
if (!window.RTCPeerConnection && !window.webkitRTCPeerConnection) {
log('WebRTC not supported — try Chrome or Edge', 'error');
return;
}
var tWhep = performance.now();
connectWhep().then(function() {
connectWhep(whepUrlFromServer).then(function() {
log('WHEP connected (' + Math.round(performance.now() - tWhep) + 'ms)', 'ok');
}).catch(function(e) {
log('WHEP failed (' + Math.round(performance.now() - tWhep) + 'ms): ' + e.message, 'error');
});
return;
}
};
@@ -1150,7 +915,6 @@ function stopStreaming() {
streaming = false;
updatePipVisibility();
stopStats();
cleanupMse();
if (captureTimer) {
clearInterval(captureTimer);
@@ -1162,12 +926,10 @@ function stopStreaming() {
document.getElementById('btnStop').disabled = true;
document.getElementById('timeSlider').disabled = true;
document.getElementById('dotStream').classList.remove('done');
document.getElementById('dotMode').classList.remove('done');
}
function disconnect() {
stopAudioCapture();
stopAudioEcho();
stopTimelineSync();
if (pc) {
-37
View File
@@ -97,40 +97,3 @@ def test_stream_image(test_client : TestClient) -> None:
output_vision_frame = cv2.imdecode(numpy.frombuffer(output_bytes, numpy.uint8), cv2.IMREAD_COLOR)
assert output_vision_frame.shape == (1024, 1024, 3)
def test_stream_whip(test_client : TestClient) -> None:
	# open a session to obtain the bearer token used by the authorized endpoints
	session_response = test_client.post('/session', json =
	{
		'client_version': metadata.get('version')
	})
	access_token = session_response.json().get('access_token')
	authorization_header =\
	{
		'Authorization': 'Bearer ' + access_token
	}
	# upload the example image and select it as the active source asset
	with open(get_test_example_file('source.jpg'), 'rb') as source_file:
		source_content = source_file.read()
	upload_response = test_client.post('/assets?type=source', headers = authorization_header, files =
	[
		('file', ('source.jpg', source_content, 'image/jpeg'))
	])
	asset_id = upload_response.json().get('asset_ids')[0]
	test_client.put('/state?action=select&type=source', json =
	{
		'asset_ids': [ asset_id ]
	}, headers = authorization_header)
	# the websocket carries the access token via subprotocol negotiation
	with test_client.websocket_connect('/stream/whip', subprotocols =
	[
		'access_token.' + access_token
	]) as websocket:
		websocket.send_bytes(source_content)
	assert True
BIN
View File
Binary file not shown.
-619
View File
@@ -1,619 +0,0 @@
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <pthread.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <rtc/rtc.h>
#define MAX_SESSIONS 16
#define MAX_VIEWERS 8
#define MAX_TRACKS 4
#define SDP_BUF_SIZE 16384
#define HTTP_BUF_SIZE 65536
/* One WHEP subscriber attached to a session. */
typedef struct
{
	int pc; /* libdatachannel peer connection handle */
	int tracks[MAX_TRACKS]; /* outbound video track handles */
	int track_count; /* number of used entries in tracks[] */
	int audio_track; /* outbound audio track handle (0 when absent) */
	int connected; /* set once the peer connection reaches RTC_CONNECTED */
} Viewer;
/* One ingest stream: a loopback UDP socket fanned out to up to
 * MAX_VIEWERS WHEP viewers by receiver_thread. */
typedef struct
{
	char path[256]; /* stream identifier taken from the HTTP path */
	int rtp_port; /* local UDP port the producer sends tagged payloads to */
	int rtp_fd; /* bound UDP socket descriptor */
	Viewer viewers[MAX_VIEWERS];
	int viewer_count;
	int active; /* slot-in-use flag */
	uint32_t audio_pts; /* running audio RTP timestamp, advanced by 960 per packet */
	pthread_mutex_t lock; /* guards viewers[], viewer_count and audio_pts */
} Session;
/* Callback context shared between the HTTP thread and the rtc callbacks
 * while a single viewer peer connection is negotiated. */
typedef struct
{
	Session *session;
	int viewer_index; /* index into session->viewers for this viewer */
	int gathering_done; /* set when ICE gathering completes */
	pthread_mutex_t gather_lock; /* protects gathering_done */
	pthread_cond_t gather_cond; /* signaled once gathering_done is set */
} ViewerCtx;
static Session sessions[MAX_SESSIONS];
static pthread_mutex_t sessions_lock = PTHREAD_MUTEX_INITIALIZER;
static volatile int running = 1;
static int next_rtp_port = 15000;
static Session *find_session(const char *path)
{
for (int i = 0; i < MAX_SESSIONS; i++)
{
if (sessions[i].active && strcmp(sessions[i].path, path) == 0)
{
return &sessions[i];
}
}
return NULL;
}
/* Seconds elapsed since *start, with microsecond resolution. */
static double get_elapsed_seconds(struct timeval *start)
{
	struct timeval current;
	gettimeofday(&current, NULL);
	double seconds = (double)(current.tv_sec - start->tv_sec);
	double microseconds = (double)(current.tv_usec - start->tv_usec);
	return seconds + microseconds / 1000000.0;
}
/* Per-session fan-out thread: reads tagged datagrams from the session's
 * loopback UDP socket and forwards the payload to every connected viewer.
 * Wire format: first byte is a tag (0x01 = video payload, 0x02 = audio
 * packet), the remainder is the raw payload.
 * Fix: audio_pts now advances once per audio packet; previously it was
 * incremented inside the per-viewer loop, so with more than one viewer
 * the audio RTP timestamp ran ahead by a factor of the viewer count. */
static void *receiver_thread(void *arg)
{
	Session *session = (Session *)arg;
	char buf[256 * 1024];
	struct timeval start_time;
	int started = 0;
	while (running && session->active)
	{
		struct sockaddr_in from;
		socklen_t fromlen = sizeof(from);
		/* the socket has a 1s SO_RCVTIMEO, keeping this loop responsive to shutdown */
		int n = recvfrom(session->rtp_fd, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
		if (n <= 1)
		{
			continue; /* timeout, error, or tag byte with no payload */
		}
		char tag = buf[0];
		char *payload = buf + 1;
		int payload_len = n - 1;
		/* video timestamps are wall-clock relative to the first video packet */
		if (!started && tag == 0x01)
		{
			gettimeofday(&start_time, NULL);
			started = 1;
		}
		pthread_mutex_lock(&session->lock);
		for (int v = 0; v < session->viewer_count; v++)
		{
			Viewer *viewer = &session->viewers[v];
			if (!viewer->connected)
			{
				continue;
			}
			if (tag == 0x01)
			{
				for (int t = 0; t < viewer->track_count; t++)
				{
					if (!rtcIsOpen(viewer->tracks[t]))
					{
						continue;
					}
					double elapsed = started ? get_elapsed_seconds(&start_time) : 0;
					uint32_t timestamp = (uint32_t)(elapsed * 90000.0); /* 90 kHz video clock */
					rtcSetTrackRtpTimestamp(viewer->tracks[t], timestamp);
					rtcSendMessage(viewer->tracks[t], payload, payload_len);
				}
			}
			if (tag == 0x02 && viewer->audio_track > 0 && rtcIsOpen(viewer->audio_track))
			{
				rtcSetTrackRtpTimestamp(viewer->audio_track, session->audio_pts);
				rtcSendMessage(viewer->audio_track, payload, payload_len);
			}
		}
		if (tag == 0x02)
		{
			/* one increment per packet so every viewer sees the same clock;
			 * 960 matches the producer's packet duration at a 48 kHz clock */
			session->audio_pts += 960;
		}
		pthread_mutex_unlock(&session->lock);
	}
	return NULL;
}
/* Claim a free session slot for path, bind its loopback RTP socket and
 * start the fan-out receiver thread. Callers hold sessions_lock.
 * Returns the new session, or NULL when no slot is free or setup fails.
 * Fix: socket() and pthread_create() failures are now handled; previously
 * a failed socket() passed -1 to bind(), and a failed pthread_create()
 * left an active session with no receiver and leaked the socket. */
static Session *create_session_slot(const char *path)
{
	for (int i = 0; i < MAX_SESSIONS; i++)
	{
		if (sessions[i].active)
		{
			continue;
		}
		memset(&sessions[i], 0, sizeof(Session));
		strncpy(sessions[i].path, path, sizeof(sessions[i].path) - 1); /* memset above guarantees termination */
		sessions[i].active = 1;
		sessions[i].rtp_port = next_rtp_port++;
		pthread_mutex_init(&sessions[i].lock, NULL);
		int fd = socket(AF_INET, SOCK_DGRAM, 0);
		if (fd < 0)
		{
			perror("socket rtp");
			sessions[i].active = 0;
			return NULL;
		}
		struct sockaddr_in addr;
		memset(&addr, 0, sizeof(addr));
		addr.sin_family = AF_INET;
		addr.sin_addr.s_addr = inet_addr("127.0.0.1"); /* producer sends locally only */
		addr.sin_port = htons(sessions[i].rtp_port);
		if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
		{
			perror("bind rtp");
			close(fd);
			sessions[i].active = 0;
			return NULL;
		}
		/* 1s receive timeout so the receiver thread can notice shutdown */
		struct timeval tv;
		tv.tv_sec = 1;
		tv.tv_usec = 0;
		setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
		sessions[i].rtp_fd = fd;
		sessions[i].audio_pts = 0;
		pthread_t tid;
		if (pthread_create(&tid, NULL, receiver_thread, &sessions[i]) != 0)
		{
			perror("pthread_create receiver");
			close(fd);
			sessions[i].active = 0;
			return NULL;
		}
		pthread_detach(tid);
		return &sessions[i];
	}
	return NULL; /* all slots in use */
}
static ViewerCtx viewer_contexts[MAX_SESSIONS * MAX_VIEWERS];
static int viewer_ctx_count = 0;
/* Peer-connection state callback: mark the viewer connected once the
 * transport is up, so receiver_thread starts forwarding to it. */
static void on_viewer_state(int pc, rtcState state, void *ptr)
{
	ViewerCtx *ctx = (ViewerCtx *)ptr;
	if (state != RTC_CONNECTED || !ctx->session)
	{
		return;
	}
	Session *session = ctx->session;
	pthread_mutex_lock(&session->lock);
	if (ctx->viewer_index < session->viewer_count)
	{
		session->viewers[ctx->viewer_index].connected = 1;
	}
	pthread_mutex_unlock(&session->lock);
}
/* ICE gathering callback: wake the HTTP thread blocked in create_viewer_pc
 * once the local description is complete. */
static void on_viewer_gathering(int pc, rtcGatheringState state, void *ptr)
{
	if (state != RTC_GATHERING_COMPLETE)
	{
		return;
	}
	ViewerCtx *ctx = (ViewerCtx *)ptr;
	pthread_mutex_lock(&ctx->gather_lock);
	ctx->gathering_done = 1;
	pthread_cond_signal(&ctx->gather_cond);
	pthread_mutex_unlock(&ctx->gather_lock);
}
/* Negotiate one WHEP viewer for session: build a sendonly VP8 + audio peer
 * connection, apply offer_sdp, wait up to 5s for ICE gathering, and copy
 * the local SDP answer into answer_buf. Returns 0 on success, -1 on failure.
 * Fixes: viewer_ctx_count is now bounds-checked (contexts are never
 * reclaimed, so the array could previously overflow after enough viewers);
 * the peer connection is deleted when rtcGetLocalDescription fails instead
 * of being leaked. */
static int create_viewer_pc(Session *session, const char *offer_sdp, char *answer_buf, int answer_size)
{
	if (session->viewer_count >= MAX_VIEWERS)
	{
		return -1;
	}
	if (viewer_ctx_count >= MAX_SESSIONS * MAX_VIEWERS)
	{
		return -1; /* context pool exhausted; refuse instead of overflowing */
	}
	rtcConfiguration config;
	memset(&config, 0, sizeof(config));
	config.forceMediaTransport = true;
	config.enableIceUdpMux = true;
	int pc = rtcCreatePeerConnection(&config);
	if (pc < 0)
	{
		return -1;
	}
	ViewerCtx *ctx = &viewer_contexts[viewer_ctx_count++];
	ctx->session = session;
	ctx->viewer_index = session->viewer_count;
	ctx->gathering_done = 0;
	pthread_mutex_init(&ctx->gather_lock, NULL);
	pthread_cond_init(&ctx->gather_cond, NULL);
	Viewer *viewer = &session->viewers[session->viewer_count];
	viewer->pc = pc;
	viewer->connected = 0;
	viewer->track_count = 0;
	viewer->audio_track = 0; /* stays 0 when the audio track cannot be added */
	rtcSetUserPointer(pc, ctx);
	rtcSetGatheringStateChangeCallback(pc, on_viewer_gathering);
	rtcSetStateChangeCallback(pc, on_viewer_state);
	/* video: VP8 at 90 kHz, payload type 96, with SR reports and NACK replay */
	int video_track = rtcAddTrack(pc,
		"m=video 9 UDP/TLS/RTP/SAVPF 96\r\n"
		"a=rtpmap:96 VP8/90000\r\n"
		"a=sendonly\r\na=mid:0\r\na=rtcp-mux\r\n");
	if (video_track >= 0)
	{
		rtcPacketizerInit packetizer;
		memset(&packetizer, 0, sizeof(packetizer));
		packetizer.ssrc = 42;
		packetizer.cname = "video";
		packetizer.payloadType = 96;
		packetizer.clockRate = 90000;
		packetizer.maxFragmentSize = 1200; /* keep RTP packets under a typical MTU */
		rtcSetVP8Packetizer(video_track, &packetizer);
		rtcChainRtcpSrReporter(video_track);
		rtcChainRtcpNackResponder(video_track, 512);
		viewer->tracks[viewer->track_count++] = video_track;
	}
	/* audio: opus stereo at 48 kHz, payload type 111 */
	int audio_track = rtcAddTrack(pc,
		"m=audio 9 UDP/TLS/RTP/SAVPF 111\r\n"
		"a=rtpmap:111 opus/48000/2\r\n"
		"a=sendonly\r\na=mid:1\r\na=rtcp-mux\r\n");
	if (audio_track >= 0)
	{
		rtcPacketizerInit audio_packetizer;
		memset(&audio_packetizer, 0, sizeof(audio_packetizer));
		audio_packetizer.ssrc = 43;
		audio_packetizer.cname = "audio";
		audio_packetizer.payloadType = 111;
		audio_packetizer.clockRate = 48000;
		rtcSetOpusPacketizer(audio_track, &audio_packetizer);
		rtcChainRtcpSrReporter(audio_track);
		viewer->audio_track = audio_track;
	}
	rtcSetRemoteDescription(pc, offer_sdp, "offer");
	/* wait (bounded) for ICE gathering so the answer contains the candidates */
	struct timespec ts;
	clock_gettime(CLOCK_REALTIME, &ts);
	ts.tv_sec += 5;
	pthread_mutex_lock(&ctx->gather_lock);
	while (!ctx->gathering_done)
	{
		if (pthread_cond_timedwait(&ctx->gather_cond, &ctx->gather_lock, &ts) != 0)
		{
			break; /* timeout: proceed with whatever candidates we have */
		}
	}
	pthread_mutex_unlock(&ctx->gather_lock);
	int len = rtcGetLocalDescription(pc, answer_buf, answer_size);
	if (len < 0)
	{
		rtcDeletePeerConnection(pc); /* previously leaked on this path */
		return -1;
	}
	pthread_mutex_lock(&session->lock);
	session->viewer_count++;
	pthread_mutex_unlock(&session->lock);
	return 0;
}
/* Split a raw HTTP request into method, path and body.
 * method must hold >= 16 bytes and path >= 256 bytes (sscanf widths below).
 * Missing parts yield empty strings and *body_len == 0.
 * Fix: the body copy is now clamped to the caller's buffer capacity —
 * callers pass body[SDP_BUF_SIZE] while the raw request can be up to
 * HTTP_BUF_SIZE bytes, so the unbounded memcpy was a stack overflow. */
static void parse_http_request(const char *buf, int len, char *method, char *path, char *body, int *body_len)
{
	/* NOTE: keep in sync with SDP_BUF_SIZE — every caller passes body[SDP_BUF_SIZE] */
	enum { BODY_CAPACITY = 16384 };
	method[0] = 0;
	path[0] = 0;
	body[0] = 0;
	*body_len = 0;
	sscanf(buf, "%15s %255s", method, path);
	const char *body_start = strstr(buf, "\r\n\r\n");
	if (!body_start)
	{
		return;
	}
	body_start += 4;
	int remaining = len - (int)(body_start - buf);
	if (remaining <= 0)
	{
		return;
	}
	if (remaining > BODY_CAPACITY - 1)
	{
		remaining = BODY_CAPACITY - 1; /* truncate rather than overflow */
	}
	memcpy(body, body_start, remaining);
	body[remaining] = 0;
	*body_len = remaining;
}
/* Write a minimal HTTP/1.1 response with permissive CORS headers.
 * Short writes are not retried: the caller closes the socket right after,
 * and a truncated response simply reads as a failed request to the client.
 * Fix: status 500 previously fell through to the "Not Found" reason phrase. */
static void send_http_response(int fd, int status, const char *content_type, const char *body, int body_len)
{
	const char *status_text;
	switch (status)
	{
		case 200: status_text = "OK"; break;
		case 201: status_text = "Created"; break;
		case 500: status_text = "Internal Server Error"; break;
		default: status_text = "Not Found"; break;
	}
	char header[1024];
	int hlen = snprintf(header, sizeof(header),
		"HTTP/1.1 %d %s\r\n"
		"Content-Type: %s\r\n"
		"Content-Length: %d\r\n"
		"Access-Control-Allow-Origin: *\r\n"
		"Access-Control-Allow-Methods: POST, DELETE, OPTIONS, GET\r\n"
		"Access-Control-Allow-Headers: Content-Type\r\n"
		"Connection: close\r\n"
		"\r\n",
		status, status_text, content_type, body_len);
	if (write(fd, header, hlen) < 0)
	{
		return; /* peer went away; skip the body */
	}
	if (body_len > 0)
	{
		ssize_t written = write(fd, body, body_len);
		(void)written; /* best effort, see note above */
	}
}
/* Serve one HTTP connection, then close it. Single-threaded: the accept
 * loop in main blocks until this returns.
 * Endpoints:
 *   OPTIONS <any>          -> 200 empty (CORS preflight)
 *   GET  /health           -> 200 "ok"
 *   POST /<stream>/create  -> create or reuse a session, reply with its RTP port
 *   GET  /session/<stream> -> 200 "ok" when the session exists, else 404
 *   POST /<stream>/whep    -> body is an SDP offer, reply 201 with the SDP answer
 * Fixes: buf is NUL-terminated before every strstr (it was searched while
 * unterminated/uninitialized); stream_path is terminated after strncpy,
 * which does not terminate a >= 255 byte path. */
static void handle_client(int client_fd)
{
	char buf[HTTP_BUF_SIZE];
	int total = 0;
	int n;
	while (total < HTTP_BUF_SIZE - 1)
	{
		n = read(client_fd, buf + total, HTTP_BUF_SIZE - 1 - total);
		if (n <= 0)
		{
			break;
		}
		total += n;
		buf[total] = 0; /* keep buf a valid C string for the strstr calls */
		if (strstr(buf, "\r\n\r\n"))
		{
			/* headers complete: drain the body up to Content-Length bytes */
			int content_length = 0;
			char *cl = strstr(buf, "Content-Length:");
			if (!cl)
			{
				cl = strstr(buf, "content-length:");
			}
			if (cl)
			{
				content_length = atoi(cl + 15); /* 15 == strlen("Content-Length:") */
			}
			char *body_start = strstr(buf, "\r\n\r\n") + 4;
			int header_len = body_start - buf;
			int body_so_far = total - header_len;
			while (body_so_far < content_length && total < HTTP_BUF_SIZE - 1)
			{
				n = read(client_fd, buf + total, HTTP_BUF_SIZE - 1 - total);
				if (n <= 0)
				{
					break;
				}
				total += n;
				buf[total] = 0;
				body_so_far = total - header_len;
			}
			break;
		}
	}
	buf[total] = 0;
	char method[16], path[256], body[SDP_BUF_SIZE];
	int body_len;
	parse_http_request(buf, total, method, path, body, &body_len);
	if (strcmp(method, "OPTIONS") == 0)
	{
		send_http_response(client_fd, 200, "text/plain", "", 0);
		close(client_fd);
		return;
	}
	if (strcmp(method, "GET") == 0 && strcmp(path, "/health") == 0)
	{
		send_http_response(client_fd, 200, "text/plain", "ok", 2);
		close(client_fd);
		return;
	}
	if (strcmp(method, "POST") == 0 && strstr(path, "/create"))
	{
		/* "/<stream>/create" -> session keyed by "<stream>" */
		char stream_path[256];
		strncpy(stream_path, path + 1, sizeof(stream_path) - 1);
		stream_path[sizeof(stream_path) - 1] = 0;
		char *create_pos = strstr(stream_path, "/create");
		if (create_pos)
		{
			*create_pos = 0;
		}
		pthread_mutex_lock(&sessions_lock);
		Session *session = find_session(stream_path);
		if (!session)
		{
			session = create_session_slot(stream_path);
		}
		pthread_mutex_unlock(&sessions_lock);
		if (session)
		{
			char port_str[16];
			snprintf(port_str, sizeof(port_str), "%d", session->rtp_port);
			send_http_response(client_fd, 200, "text/plain", port_str, strlen(port_str));
		}
		else
		{
			send_http_response(client_fd, 500, "text/plain", "failed", 6);
		}
		close(client_fd);
		return;
	}
	if (strcmp(method, "GET") == 0 && strncmp(path, "/session/", 9) == 0)
	{
		const char *check_path = path + 9;
		pthread_mutex_lock(&sessions_lock);
		Session *s = find_session(check_path);
		pthread_mutex_unlock(&sessions_lock);
		if (s)
		{
			send_http_response(client_fd, 200, "text/plain", "ok", 2);
		}
		else
		{
			send_http_response(client_fd, 404, "text/plain", "no", 2);
		}
		close(client_fd);
		return;
	}
	if (strcmp(method, "POST") != 0 || !strstr(path, "/whep"))
	{
		send_http_response(client_fd, 404, "text/plain", "not found", 9);
		close(client_fd);
		return;
	}
	/* WHEP: "/<stream>/whep" -> negotiate a viewer for "<stream>" */
	char stream_path[256];
	char *whep_pos = strstr(path + 1, "/whep");
	int plen = whep_pos - path - 1;
	strncpy(stream_path, path + 1, plen);
	stream_path[plen] = 0;
	char answer[SDP_BUF_SIZE];
	pthread_mutex_lock(&sessions_lock);
	Session *session = find_session(stream_path);
	pthread_mutex_unlock(&sessions_lock);
	if (!session)
	{
		send_http_response(client_fd, 404, "text/plain", "no session", 10);
		close(client_fd);
		return;
	}
	int rc = create_viewer_pc(session, body, answer, SDP_BUF_SIZE);
	if (rc < 0)
	{
		send_http_response(client_fd, 500, "text/plain", "failed", 6);
	}
	else
	{
		send_http_response(client_fd, 201, "application/sdp", answer, strlen(answer));
	}
	close(client_fd);
}
/* SIGINT/SIGTERM handler: request shutdown of the accept and receiver loops. */
static void signal_handler(int sig)
{
	(void)sig; /* same action for both signals */
	running = 0;
}
/* Entry point: single-threaded HTTP signaling server.
 * Usage: whip_relay [port] (default 8891). Connections are served one at a
 * time by handle_client; media fan-out happens on per-session threads. */
int main(int argc, char *argv[])
{
	int port = 8891;
	if (argc > 1)
	{
		port = atoi(argv[1]);
	}
	/* let SIGINT/SIGTERM break the accept loop via the running flag */
	signal(SIGINT, signal_handler);
	signal(SIGTERM, signal_handler);
	rtcInitLogger(RTC_LOG_WARNING, NULL);
	memset(sessions, 0, sizeof(sessions));
	int server_fd = socket(AF_INET, SOCK_STREAM, 0);
	if (server_fd < 0)
	{
		perror("socket");
		return 1;
	}
	/* allow quick restarts without waiting out TIME_WAIT */
	int opt = 1;
	setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
	struct sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = INADDR_ANY;
	addr.sin_port = htons(port);
	if (bind(server_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
	{
		perror("bind");
		close(server_fd);
		return 1;
	}
	if (listen(server_fd, 16) < 0)
	{
		perror("listen");
		close(server_fd);
		return 1;
	}
	fprintf(stderr, "whip_relay listening on port %d\n", port);
	while (running)
	{
		struct sockaddr_in client_addr;
		socklen_t client_len = sizeof(client_addr);
		int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len);
		if (client_fd < 0)
		{
			continue; /* includes EINTR from the shutdown signals */
		}
		handle_client(client_fd); /* blocking; closes client_fd itself */
	}
	close(server_fd);
	rtcCleanup();
	return 0;
}