mass test approaches

This commit is contained in:
henryruhs
2026-03-23 13:33:50 +01:00
parent 44f8f1e83b
commit 021b9a15f5
6 changed files with 1579 additions and 3 deletions
+336
View File
@@ -0,0 +1,336 @@
import asyncio
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional
import numpy
from aiortc import RTCPeerConnection, RTCSessionDescription, AudioStreamTrack, VideoStreamTrack
from av import AudioFrame, VideoFrame
from facefusion import logger
from facefusion.types import VisionFrame
BRIDGE_PORT_START : int = 8893
AUDIO_SAMPLE_RATE : int = 48000
class FramePushTrack(VideoStreamTrack):
kind = 'video'
def __init__(self) -> None:
super().__init__()
self._frame : Optional[VisionFrame] = None
self._lock = threading.Lock()
self._started = False
def push(self, vision_frame : VisionFrame) -> None:
with self._lock:
self._frame = vision_frame
async def recv(self) -> VideoFrame:
pts, time_base = await self.next_timestamp()
with self._lock:
frame_data = self._frame
if frame_data is None:
frame_data = numpy.zeros((240, 320, 3), dtype = numpy.uint8)
if not self._started:
self._started = True
logger.info('aiortc track sending first frame', __name__)
video_frame = VideoFrame.from_ndarray(frame_data, format = 'bgr24')
video_frame.pts = pts
video_frame.time_base = time_base
return video_frame
class AudioPushTrack(AudioStreamTrack):
kind = 'audio'
def __init__(self) -> None:
super().__init__()
self._buffer = bytearray()
self._lock = threading.Lock()
self._pts = 0
self._frame_samples = 960
def push(self, pcm_data : bytes) -> None:
with self._lock:
self._buffer.extend(pcm_data)
if len(self._buffer) > AUDIO_SAMPLE_RATE * 4:
self._buffer = self._buffer[-AUDIO_SAMPLE_RATE * 4:]
async def recv(self) -> AudioFrame:
await asyncio.sleep(self._frame_samples / AUDIO_SAMPLE_RATE)
needed = self._frame_samples * 2 * 2
with self._lock:
if len(self._buffer) >= needed:
chunk = bytes(self._buffer[:needed])
del self._buffer[:needed]
else:
chunk = None
if chunk:
pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1)
else:
pcm = numpy.zeros((1, self._frame_samples * 2), dtype = numpy.int16)
audio_frame = AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo')
audio_frame.sample_rate = AUDIO_SAMPLE_RATE
audio_frame.pts = self._pts
self._pts += self._frame_samples
return audio_frame
class AiortcBridge:
def __init__(self) -> None:
global BRIDGE_PORT_START
self.port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self.video_track = FramePushTrack()
self.audio_track = AudioPushTrack()
self.pcs : list = []
self._http_thread : Optional[threading.Thread] = None
self._running = False
self._has_viewer = False
self._loop = None
async def start(self) -> None:
self._running = True
self._loop = asyncio.get_event_loop()
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
self._http_thread.start()
logger.info('aiortc bridge started on port ' + str(self.port), __name__)
async def stop(self) -> None:
self._running = False
for pc in self.pcs:
try:
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(pc.close(), loop)
except Exception:
pass
def push_frame(self, vision_frame : VisionFrame) -> None:
self.video_track.push(vision_frame)
def push_audio(self, audio_data : bytes) -> None:
self.audio_track.push(audio_data)
def has_viewer(self) -> bool:
return self._has_viewer
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_pc(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whep error: ' + str(exception), __name__)
return None
async def _create_pc(self, sdp_offer : str) -> Optional[str]:
pc = RTCPeerConnection()
self.pcs.append(pc)
pc.addTrack(self.video_track)
pc.addTrack(self.audio_track)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
self._has_viewer = True
return pc.localDescription.sdp
def _run_http(self) -> None:
bridge = self
class WhepHandler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
self.end_headers()
def do_POST(self) -> None:
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
answer = bridge._handle_whep(body)
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
else:
self.send_response(500)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
server = HTTPServer(('0.0.0.0', self.port), WhepHandler)
server.timeout = 1
while self._running:
server.handle_request()
class WhipAiortcBridge:
def __init__(self) -> None:
global BRIDGE_PORT_START
self.port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self.whip_port = BRIDGE_PORT_START
BRIDGE_PORT_START += 1
self._ingest_pc = None
self._relay_track = None
self._viewer_pcs : list = []
self._http_thread : Optional[threading.Thread] = None
self._running = False
self._loop = None
self._ingest_ready = False
async def start(self) -> None:
self._running = True
self._loop = asyncio.get_event_loop()
self._http_thread = threading.Thread(target = self._run_http, daemon = True)
self._http_thread.start()
logger.info('whip-aiortc bridge whip=' + str(self.whip_port) + ' whep=' + str(self.port), __name__)
async def stop(self) -> None:
self._running = False
if self._ingest_pc:
await self._ingest_pc.close()
for pc in self._viewer_pcs:
await pc.close()
def get_whip_url(self) -> str:
return 'http://localhost:' + str(self.whip_port) + '/whip'
def get_whep_url(self) -> str:
return 'http://localhost:' + str(self.port) + '/whep'
def is_ready(self) -> bool:
return self._ingest_ready
def _handle_whip(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_ingest(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whip ingest error: ' + str(exception), __name__)
return None
def _handle_whep(self, sdp_offer : str) -> Optional[str]:
if not self._loop:
return None
future = asyncio.run_coroutine_threadsafe(self._create_viewer(sdp_offer), self._loop)
try:
return future.result(timeout = 10)
except Exception as exception:
logger.error('whep error: ' + str(exception), __name__)
return None
async def _create_ingest(self, sdp_offer : str) -> Optional[str]:
from aiortc import MediaStreamTrack
from aiortc.contrib.media import MediaRelay
pc = RTCPeerConnection()
self._ingest_pc = pc
self._relay = MediaRelay()
@pc.on('track')
def on_track(track : MediaStreamTrack) -> None:
if track.kind == 'video':
self._relay_track = self._relay.subscribe(track)
self._ingest_ready = True
logger.info('whip ingest video track received', __name__)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return pc.localDescription.sdp
async def _create_viewer(self, sdp_offer : str) -> Optional[str]:
pc = RTCPeerConnection()
self._viewer_pcs.append(pc)
if self._relay_track:
pc.addTrack(self._relay_track)
offer = RTCSessionDescription(sdp = sdp_offer, type = 'offer')
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return pc.localDescription.sdp
def _run_http(self) -> None:
bridge = self
class Handler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS, DELETE')
self.send_header('Access-Control-Allow-Headers', 'Content-Type, Authorization')
self.end_headers()
def do_POST(self) -> None:
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
path = self.path
if '/whip' in path:
answer = bridge._handle_whip(body)
elif '/whep' in path:
answer = bridge._handle_whep(body)
else:
self.send_response(404)
self.end_headers()
return
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Location', path)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Expose-Headers', 'Location')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
else:
self.send_response(500)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
whip_server = HTTPServer(('0.0.0.0', self.whip_port), Handler)
whip_server.timeout = 0.5
whep_server = HTTPServer(('0.0.0.0', self.port), Handler)
whep_server.timeout = 0.5
while self._running:
whip_server.handle_request()
whep_server.handle_request()
+3
View File
@@ -805,6 +805,9 @@ async def websocket_stream_rtc(websocket : WebSocket) -> None:
with lock:
latest_frame_holder[0] = frame
if data[:2] != JPEG_MAGIC:
rtc.send_audio(stream_path, data)
except Exception as exception:
logger.error(str(exception), __name__)
+592
View File
@@ -0,0 +1,592 @@
import ctypes
import ctypes.util
import os
import threading
import time as _time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Dict, List, Optional, TypeAlias
import av
import numpy
from facefusion import logger
RtcLib : TypeAlias = ctypes.CDLL
WHEP_PORT : int = 8892
lib : Optional[RtcLib] = None
sessions : Dict[str, dict] = {}
http_thread : Optional[threading.Thread] = None
running : bool = False
RTC_NEW = 0
RTC_CONNECTING = 1
RTC_CONNECTED = 2
RTC_DISCONNECTED = 3
RTC_FAILED = 4
RTC_CLOSED = 5
RTC_GATHERING_NEW = 0
RTC_GATHERING_INPROGRESS = 1
RTC_GATHERING_COMPLETE = 2
RTC_DIRECTION_SENDONLY = 0
RTC_DIRECTION_RECVONLY = 1
RTC_DIRECTION_SENDRECV = 2
RTC_DIRECTION_INACTIVE = 3
RTC_DIRECTION_UNKNOWN = 4
LOG_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)
DESCRIPTION_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
CANDIDATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_void_p)
STATE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
GATHERING_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
TRACK_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
MESSAGE_CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p)
class RtcConfiguration(ctypes.Structure):
_fields_ =\
[
('iceServers', ctypes.POINTER(ctypes.c_char_p)),
('iceServersCount', ctypes.c_int),
('proxyServer', ctypes.c_char_p),
('bindAddress', ctypes.c_char_p),
('certificateType', ctypes.c_int),
('iceTransportPolicy', ctypes.c_int),
('enableIceTcp', ctypes.c_bool),
('enableIceUdpMux', ctypes.c_bool),
('disableAutoNegotiation', ctypes.c_bool),
('forceMediaTransport', ctypes.c_bool),
('portRangeBegin', ctypes.c_ushort),
('portRangeEnd', ctypes.c_ushort),
('mtu', ctypes.c_int),
('maxMessageSize', ctypes.c_int)
]
class RtcPacketizerInit(ctypes.Structure):
_fields_ =\
[
('ssrc', ctypes.c_uint32),
('cname', ctypes.c_char_p),
('payloadType', ctypes.c_uint8),
('clockRate', ctypes.c_uint32),
('sequenceNumber', ctypes.c_uint16),
('timestamp', ctypes.c_uint32),
('maxFragmentSize', ctypes.c_uint16),
('nalSeparator', ctypes.c_int),
('obuPacketization', ctypes.c_int),
('playoutDelayId', ctypes.c_uint8),
('playoutDelayMin', ctypes.c_uint16),
('playoutDelayMax', ctypes.c_uint16),
('colorSpaceId', ctypes.c_uint8),
('colorChromaSitingHorz', ctypes.c_uint8),
('colorChromaSitingVert', ctypes.c_uint8),
('colorRange', ctypes.c_uint8),
('colorPrimaries', ctypes.c_uint8),
('colorTransfer', ctypes.c_uint8),
('colorMatrix', ctypes.c_uint8)
]
def find_library() -> Optional[str]:
lib_path = ctypes.util.find_library('datachannel')
if lib_path:
return lib_path
search_paths =\
[
'/home/henry/local/lib/libdatachannel.so',
'/usr/local/lib/libdatachannel.so',
'/usr/lib/libdatachannel.so',
'/usr/lib/x86_64-linux-gnu/libdatachannel.so'
]
for path in search_paths:
if os.path.isfile(path):
return path
return None
def load_library() -> bool:
global lib
lib_path = find_library()
if not lib_path:
logger.warn('libdatachannel.so not found', __name__)
return False
lib = ctypes.CDLL(lib_path)
setup_prototypes()
lib.rtcInitLogger(4, LOG_CALLBACK_TYPE(0))
logger.info('libdatachannel loaded from ' + lib_path, __name__)
return True
def setup_prototypes() -> None:
lib.rtcInitLogger.argtypes = [ctypes.c_int, LOG_CALLBACK_TYPE]
lib.rtcInitLogger.restype = None
lib.rtcCreatePeerConnection.argtypes = [ctypes.POINTER(RtcConfiguration)]
lib.rtcCreatePeerConnection.restype = ctypes.c_int
lib.rtcDeletePeerConnection.argtypes = [ctypes.c_int]
lib.rtcDeletePeerConnection.restype = ctypes.c_int
lib.rtcSetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p]
lib.rtcSetLocalDescription.restype = ctypes.c_int
lib.rtcSetRemoteDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
lib.rtcSetRemoteDescription.restype = ctypes.c_int
lib.rtcGetLocalDescription.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int]
lib.rtcGetLocalDescription.restype = ctypes.c_int
lib.rtcAddTrack.argtypes = [ctypes.c_int, ctypes.c_char_p]
lib.rtcAddTrack.restype = ctypes.c_int
lib.rtcSetUserPointer.argtypes = [ctypes.c_int, ctypes.c_void_p]
lib.rtcSetUserPointer.restype = None
lib.rtcSetLocalDescriptionCallback.argtypes = [ctypes.c_int, DESCRIPTION_CALLBACK_TYPE]
lib.rtcSetLocalDescriptionCallback.restype = ctypes.c_int
lib.rtcSetLocalCandidateCallback.argtypes = [ctypes.c_int, CANDIDATE_CALLBACK_TYPE]
lib.rtcSetLocalCandidateCallback.restype = ctypes.c_int
lib.rtcSetStateChangeCallback.argtypes = [ctypes.c_int, STATE_CALLBACK_TYPE]
lib.rtcSetStateChangeCallback.restype = ctypes.c_int
lib.rtcSetGatheringStateChangeCallback.argtypes = [ctypes.c_int, GATHERING_CALLBACK_TYPE]
lib.rtcSetGatheringStateChangeCallback.restype = ctypes.c_int
lib.rtcSetTrackCallback.argtypes = [ctypes.c_int, TRACK_CALLBACK_TYPE]
lib.rtcSetTrackCallback.restype = ctypes.c_int
lib.rtcSetMessageCallback.argtypes = [ctypes.c_int, MESSAGE_CALLBACK_TYPE]
lib.rtcSetMessageCallback.restype = ctypes.c_int
lib.rtcSendMessage.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
lib.rtcSendMessage.restype = ctypes.c_int
lib.rtcSetH264Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
lib.rtcSetH264Packetizer.restype = ctypes.c_int
lib.rtcSetVP8Packetizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
lib.rtcSetVP8Packetizer.restype = ctypes.c_int
lib.rtcChainRtcpSrReporter.argtypes = [ctypes.c_int]
lib.rtcChainRtcpSrReporter.restype = ctypes.c_int
lib.rtcChainRtcpNackResponder.argtypes = [ctypes.c_int, ctypes.c_uint]
lib.rtcChainRtcpNackResponder.restype = ctypes.c_int
lib.rtcSetTrackRtpTimestamp.argtypes = [ctypes.c_int, ctypes.c_uint32]
lib.rtcSetTrackRtpTimestamp.restype = ctypes.c_int
lib.rtcSetOpenCallback.argtypes = [ctypes.c_int, ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_void_p)]
lib.rtcSetOpenCallback.restype = ctypes.c_int
lib.rtcIsOpen.argtypes = [ctypes.c_int]
lib.rtcIsOpen.restype = ctypes.c_bool
lib.rtcSetOpusPacketizer.argtypes = [ctypes.c_int, ctypes.POINTER(RtcPacketizerInit)]
lib.rtcSetOpusPacketizer.restype = ctypes.c_int
callback_refs : List = []
def create_peer_connection() -> int:
config = RtcConfiguration()
config.iceServers = None
config.iceServersCount = 0
config.proxyServer = None
config.bindAddress = None
config.certificateType = 0
config.iceTransportPolicy = 0
config.enableIceTcp = False
config.enableIceUdpMux = True
config.disableAutoNegotiation = False
config.forceMediaTransport = True
config.portRangeBegin = 0
config.portRangeEnd = 0
config.mtu = 0
config.maxMessageSize = 0
return lib.rtcCreatePeerConnection(ctypes.byref(config))
next_rtp_port : int = 16000
def create_session(stream_path : str) -> None:
sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': 0, 'rtp_fd': None}
def create_rtp_session(stream_path : str) -> int:
global next_rtp_port
import socket as sock
rtp_port = next_rtp_port
next_rtp_port += 1
rtp_fd = sock.socket(sock.AF_INET, sock.SOCK_DGRAM)
rtp_fd.bind(('127.0.0.1', rtp_port))
rtp_fd.settimeout(1.0)
sessions[stream_path] = {'viewers': [], 'tracks': [], 'rtp_port': rtp_port, 'rtp_fd': rtp_fd}
rtp_thread = threading.Thread(target = run_rtp_forwarder, args = (stream_path,), daemon = True)
rtp_thread.start()
return rtp_port
def run_rtp_forwarder(stream_path : str) -> None:
session = sessions.get(stream_path)
if not session:
return
rtp_fd = session.get('rtp_fd')
while running and session.get('rtp_fd'):
try:
data, addr = rtp_fd.recvfrom(262144)
if not data:
continue
send_to_viewers(stream_path, data)
except Exception:
continue
send_start_time : float = 0
opus_encoder : Optional[av.CodecContext] = None
audio_buffer : bytearray = bytearray()
audio_lock : threading.Lock = threading.Lock()
OPUS_FRAME_SAMPLES : int = 960
def send_to_viewers(stream_path : str, data : bytes) -> None:
global send_start_time
session = sessions.get(stream_path)
if not session:
return
viewers = session.get('viewers')
if not viewers:
return
if send_start_time == 0:
send_start_time = _time.monotonic()
elapsed = _time.monotonic() - send_start_time
timestamp = int(elapsed * 90000) & 0xFFFFFFFF
buf = ctypes.create_string_buffer(data)
data_len = len(data)
for viewer in viewers:
if not viewer.get('connected'):
continue
for track_id in viewer.get('tracks', []):
if not lib.rtcIsOpen(track_id):
continue
lib.rtcSetTrackRtpTimestamp(track_id, timestamp)
lib.rtcSendMessage(track_id, buf, data_len)
def get_opus_encoder() -> av.CodecContext:
global opus_encoder
if not opus_encoder:
opus_encoder = av.CodecContext.create('libopus', 'w')
opus_encoder.sample_rate = 48000
opus_encoder.layout = 'stereo'
opus_encoder.format = av.AudioFormat('s16')
opus_encoder.open()
return opus_encoder
def send_audio(stream_path : str, pcm_data : bytes) -> None:
session = sessions.get(stream_path)
if not session:
return
viewers = session.get('viewers')
if not viewers:
return
with audio_lock:
audio_buffer.extend(pcm_data)
needed = OPUS_FRAME_SAMPLES * 2 * 2
while len(audio_buffer) >= needed:
chunk = bytes(audio_buffer[:needed])
del audio_buffer[:needed]
encoder = get_opus_encoder()
pcm = numpy.frombuffer(chunk, dtype = numpy.int16).reshape(1, -1)
frame = av.AudioFrame.from_ndarray(pcm, format = 's16', layout = 'stereo')
frame.sample_rate = 48000
frame.pts = None
for packet in encoder.encode(frame):
opus_data = bytes(packet)
for viewer in viewers:
if not viewer.get('connected'):
continue
audio_track_id = viewer.get('audio_track')
if not audio_track_id:
continue
if not lib.rtcIsOpen(audio_track_id):
continue
elapsed = _time.monotonic() - send_start_time if send_start_time > 0 else 0
timestamp = int(elapsed * 48000) & 0xFFFFFFFF
buf = ctypes.create_string_buffer(opus_data)
lib.rtcSetTrackRtpTimestamp(audio_track_id, timestamp)
lib.rtcSendMessage(audio_track_id, buf, len(opus_data))
h264_au_buffer : Dict[str, bytes] = {}
def send_vp8_frame(stream_path : str, frame_data : bytes) -> None:
send_h264_frame(stream_path, frame_data)
def send_h264_frame(stream_path : str, frame_data : bytes) -> None:
session = sessions.get(stream_path)
if not session:
return
prev = h264_au_buffer.get(stream_path, b'')
buf = prev + frame_data
au_starts = []
i = 0
while i < len(buf) - 4:
if buf[i] == 0 and buf[i + 1] == 0 and buf[i + 2] == 0 and buf[i + 3] == 1 and i + 4 < len(buf):
nal_type = buf[i + 4] & 0x1f
if nal_type == 7 or nal_type == 5:
au_starts.append(i)
i += 1
if len(au_starts) < 2:
h264_au_buffer[stream_path] = buf
return
for j in range(len(au_starts) - 1):
au = buf[au_starts[j]:au_starts[j + 1]]
for viewer in session.get('viewers', []):
tracks = viewer.get('tracks', [])
if tracks:
lib.rtcSendMessage(tracks[0], au, len(au))
h264_au_buffer[stream_path] = buf[au_starts[-1]:]
def destroy_session(stream_path : str) -> None:
session = sessions.pop(stream_path, None)
if not session:
return
for viewer in session.get('viewers', []):
pc_id = viewer.get('pc')
if pc_id is not None:
lib.rtcDeletePeerConnection(pc_id)
def send_data(stream_path : str, data : bytes) -> None:
session = sessions.get(stream_path)
if not session:
return
for viewer in session.get('viewers', []):
for track_id in viewer.get('tracks', []):
lib.rtcSendMessage(track_id, data, len(data))
def handle_whep_offer(stream_path : str, sdp_offer : str) -> Optional[str]:
session = sessions.get(stream_path)
if not session:
return None
if not lib:
return None
pc = create_peer_connection()
gathering_done = threading.Event()
local_sdp_holder = [None]
def on_description(pc_id, sdp, type_str, user_ptr):
local_sdp_holder[0] = sdp.decode('utf-8') if sdp else None
def on_candidate(pc_id, candidate, mid, user_ptr):
pass
def on_gathering(pc_id, state, user_ptr):
if state == RTC_GATHERING_COMPLETE:
gathering_done.set()
viewer = {'pc': pc, 'tracks': [], 'connected': False}
def on_state(pc_id, state, user_ptr):
if state == RTC_CONNECTED:
viewer['connected'] = True
logger.info('viewer pc connected', __name__)
desc_cb = DESCRIPTION_CALLBACK_TYPE(on_description)
cand_cb = CANDIDATE_CALLBACK_TYPE(on_candidate)
gather_cb = GATHERING_CALLBACK_TYPE(on_gathering)
state_cb = STATE_CALLBACK_TYPE(on_state)
callback_refs.extend([desc_cb, cand_cb, gather_cb, state_cb])
lib.rtcSetLocalDescriptionCallback(pc, desc_cb)
lib.rtcSetLocalCandidateCallback(pc, cand_cb)
lib.rtcSetGatheringStateChangeCallback(pc, gather_cb)
lib.rtcSetStateChangeCallback(pc, state_cb)
video_sdp = b'm=video 9 UDP/TLS/RTP/SAVPF 96\r\na=rtpmap:96 VP8/90000\r\na=sendonly\r\na=mid:0\r\na=rtcp-mux\r\n'
audio_sdp = b'm=audio 9 UDP/TLS/RTP/SAVPF 111\r\na=rtpmap:111 opus/48000/2\r\na=sendonly\r\na=mid:1\r\na=rtcp-mux\r\n'
video_track = lib.rtcAddTrack(pc, video_sdp)
audio_track = lib.rtcAddTrack(pc, audio_sdp)
video_packetizer = RtcPacketizerInit()
video_packetizer.ssrc = 42
video_packetizer.cname = b'video'
video_packetizer.payloadType = 96
video_packetizer.clockRate = 90000
video_packetizer.maxFragmentSize = 1200
lib.rtcSetVP8Packetizer(video_track, ctypes.byref(video_packetizer))
lib.rtcChainRtcpSrReporter(video_track)
lib.rtcChainRtcpNackResponder(video_track, 512)
audio_packetizer = RtcPacketizerInit()
audio_packetizer.ssrc = 43
audio_packetizer.cname = b'audio'
audio_packetizer.payloadType = 111
audio_packetizer.clockRate = 48000
lib.rtcSetOpusPacketizer(audio_track, ctypes.byref(audio_packetizer))
lib.rtcChainRtcpSrReporter(audio_track)
viewer['tracks'] = [video_track]
viewer['audio_track'] = audio_track
session['viewers'].append(viewer)
lib.rtcSetRemoteDescription(pc, sdp_offer.encode('utf-8'), b'offer')
gathering_done.wait(timeout = 3)
buf_size = 16384
buf = ctypes.create_string_buffer(buf_size)
result = lib.rtcGetLocalDescription(pc, buf, buf_size)
if result > 0:
local_sdp = buf.value.decode('utf-8')
elif local_sdp_holder[0]:
local_sdp = local_sdp_holder[0]
else:
session['viewers'].remove(viewer)
return None
return local_sdp
def start() -> None:
global running, http_thread
if not load_library():
return
running = True
http_thread = threading.Thread(target = run_http_server, daemon = True)
http_thread.start()
logger.info('rtc whep server started on port ' + str(WHEP_PORT), __name__)
def stop() -> None:
global running
running = False
for stream_path in list(sessions.keys()):
destroy_session(stream_path)
def run_http_server() -> None:
class WhepHandler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def send_cors_headers(self) -> None:
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_cors_headers()
self.end_headers()
def do_POST(self) -> None:
path = self.path
if not path.endswith('/whep'):
self.send_response(404)
self.send_cors_headers()
self.end_headers()
return
stream_path = path[1:].rsplit('/whep', 1)[0]
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
answer = handle_whep_offer(stream_path, body)
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Location', path)
self.send_cors_headers()
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
return
self.send_response(404)
self.send_cors_headers()
self.end_headers()
server = HTTPServer(('0.0.0.0', WHEP_PORT), WhepHandler)
server.timeout = 1
while running:
server.handle_request()
+546
View File
@@ -0,0 +1,546 @@
import binascii
import hashlib
import os
import socket
import struct
import threading
import time
from typing import Dict, Optional, Tuple, TypeAlias
import pylibsrtp
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from OpenSSL import SSL
from facefusion import logger
SrtpSession : TypeAlias = pylibsrtp.Session
SrtpPolicy : TypeAlias = pylibsrtp.Policy
WHIP_PORT : int = 8890
ICE_UFRAG_LENGTH : int = 4
ICE_PWD_LENGTH : int = 22
RTP_HEADER_SIZE : int = 12
SRTP_PROFILES =\
[
{
'name': b'SRTP_AES128_CM_SHA1_80',
'libsrtp': SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80,
'key_len': 16,
'salt_len': 14
}
]
sessions : Dict[str, dict] = {}
server_cert = None
server_key = None
server_fingerprint : str = ''
udp_socket : Optional[socket.socket] = None
http_thread : Optional[threading.Thread] = None
udp_thread : Optional[threading.Thread] = None
running : bool = False
def generate_credentials() -> None:
global server_cert, server_key, server_fingerprint
server_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
name = x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, binascii.hexlify(os.urandom(16)).decode())])
import datetime
now = datetime.datetime.now(tz = datetime.timezone.utc)
builder = x509.CertificateBuilder().subject_name(name).issuer_name(name).public_key(server_key.public_key()).serial_number(x509.random_serial_number()).not_valid_before(now - datetime.timedelta(days = 1)).not_valid_after(now + datetime.timedelta(days = 30))
server_cert = builder.sign(server_key, hashes.SHA256(), default_backend())
fp = server_cert.fingerprint(hashes.SHA256()).hex().upper()
server_fingerprint = ':'.join(fp[i:i + 2] for i in range(0, len(fp), 2))
def generate_ice_credentials() -> Tuple[str, str]:
ufrag = binascii.hexlify(os.urandom(ICE_UFRAG_LENGTH)).decode()
pwd = binascii.hexlify(os.urandom(ICE_PWD_LENGTH)).decode()[:ICE_PWD_LENGTH]
return ufrag, pwd
def parse_sdp_offer(sdp : str) -> dict:
result = {'ice_ufrag': '', 'ice_pwd': '', 'fingerprint': '', 'setup': '', 'media': [], 'candidates': []}
current_media = None
for line in sdp.splitlines():
line = line.strip()
if line.startswith('a=ice-ufrag:'):
result['ice_ufrag'] = line.split(':', 1)[1]
if line.startswith('a=ice-pwd:'):
result['ice_pwd'] = line.split(':', 1)[1]
if line.startswith('a=fingerprint:'):
result['fingerprint'] = line.split(' ', 1)[1] if ' ' in line else ''
if line.startswith('a=setup:'):
result['setup'] = line.split(':', 1)[1]
if line.startswith('a=candidate:'):
result['candidates'].append(line[12:])
if line.startswith('m='):
parts = line[2:].split()
current_media = {'kind': parts[0], 'port': int(parts[1]), 'profile': parts[2], 'formats': parts[3:], 'codec_lines': [], 'mid': None}
result['media'].append(current_media)
if current_media:
if line.startswith('a=rtpmap:') or line.startswith('a=fmtp:') or line.startswith('a=rtcp-fb:'):
current_media['codec_lines'].append(line)
if line.startswith('a=mid:'):
current_media['mid'] = line.split(':', 1)[1]
if line.startswith('a=extmap:'):
current_media['codec_lines'].append(line)
return result
def build_sdp_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int) -> str:
lines = []
lines.append('v=0')
lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
lines.append('s=-')
lines.append('t=0 0')
mids = []
for i, media in enumerate(offer.get('media', [])):
mids.append(str(i))
if mids:
lines.append('a=group:BUNDLE ' + ' '.join(mids))
lines.append('a=ice-lite')
for i, media in enumerate(offer.get('media', [])):
kind = media.get('kind')
formats = media.get('formats', [])
profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
mid = media.get('mid', str(i))
lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
lines.append('c=IN IP4 127.0.0.1')
lines.append('a=rtcp:9 IN IP4 0.0.0.0')
lines.append('a=ice-ufrag:' + local_ufrag)
lines.append('a=ice-pwd:' + local_pwd)
lines.append('a=ice-options:ice2')
lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
lines.append('a=setup:passive')
lines.append('a=mid:' + mid)
lines.append('a=rtcp-mux')
lines.append('a=recvonly')
for codec_line in media.get('codec_lines', []):
lines.append(codec_line)
lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
lines.append('a=end-of-candidates')
return '\r\n'.join(lines) + '\r\n'
def build_whep_answer(offer : dict, local_ufrag : str, local_pwd : str, local_port : int, ingest_offer : dict) -> str:
lines = []
lines.append('v=0')
lines.append('o=- ' + str(int(time.time())) + ' 1 IN IP4 127.0.0.1')
lines.append('s=-')
lines.append('t=0 0')
mids = []
for i, media in enumerate(offer.get('media', [])):
mid = media.get('mid', str(i))
mids.append(mid)
if mids:
lines.append('a=group:BUNDLE ' + ' '.join(mids))
lines.append('a=ice-lite')
for i, media in enumerate(offer.get('media', [])):
kind = media.get('kind')
formats = media.get('formats', [])
profile = media.get('profile', 'UDP/TLS/RTP/SAVPF')
mid = media.get('mid', str(i))
lines.append('m=' + kind + ' 9 ' + profile + ' ' + ' '.join(formats))
lines.append('c=IN IP4 127.0.0.1')
lines.append('a=rtcp:9 IN IP4 0.0.0.0')
lines.append('a=ice-ufrag:' + local_ufrag)
lines.append('a=ice-pwd:' + local_pwd)
lines.append('a=ice-options:ice2')
lines.append('a=fingerprint:sha-256 ' + server_fingerprint)
lines.append('a=setup:passive')
lines.append('a=mid:' + mid)
lines.append('a=rtcp-mux')
lines.append('a=sendonly')
for codec_line in media.get('codec_lines', []):
lines.append(codec_line)
lines.append('a=candidate:1 1 udp 2130706431 127.0.0.1 ' + str(local_port) + ' typ host')
lines.append('a=end-of-candidates')
return '\r\n'.join(lines) + '\r\n'
def create_ssl_context() -> SSL.Context:
ctx = SSL.Context(SSL.DTLS_METHOD)
ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True)
ctx.use_certificate(server_cert)
ctx.use_privatekey(server_key)
ctx.set_cipher_list(b'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-SHA')
ctx.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80')
return ctx
def create_session(stream_path : str) -> None:
ufrag, pwd = generate_ice_credentials()
sessions[stream_path] = {
'ice_ufrag': ufrag,
'ice_pwd': pwd,
'ingest': None,
'viewers': [],
'ingest_offer': None,
'rx_srtp': None,
'tx_sessions': []
}
def destroy_session(stream_path : str) -> None:
sessions.pop(stream_path, None)
def handle_whip(stream_path : str, sdp_offer : str) -> Optional[str]:
session = sessions.get(stream_path)
if not session:
return None
offer = parse_sdp_offer(sdp_offer)
session['ingest_offer'] = offer
local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
answer = build_sdp_answer(offer, session.get('ice_ufrag'), session.get('ice_pwd'), local_port)
return answer
def handle_whep(stream_path : str, sdp_offer : str) -> Optional[str]:
session = sessions.get(stream_path)
if not session:
return None
offer = parse_sdp_offer(sdp_offer)
viewer_ufrag, viewer_pwd = generate_ice_credentials()
local_port = udp_socket.getsockname()[1] if udp_socket else WHIP_PORT
ingest_offer = session.get('ingest_offer', offer)
answer = build_whep_answer(offer, viewer_ufrag, viewer_pwd, local_port, ingest_offer)
session['viewers'].append({'offer': offer, 'ice_ufrag': viewer_ufrag, 'ice_pwd': viewer_pwd})
return answer
def start() -> None:
global running, udp_socket, http_thread
generate_credentials()
running = True
udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
udp_socket.bind(('0.0.0.0', WHIP_PORT))
udp_socket.settimeout(1.0)
udp_thread_instance = threading.Thread(target = run_udp_loop, daemon = True)
udp_thread_instance.start()
http_thread = threading.Thread(target = run_http_server, daemon = True)
http_thread.start()
logger.info('webrtc sfu started on port ' + str(WHIP_PORT), __name__)
def stop() -> None:
global running, udp_socket
running = False
if udp_socket:
udp_socket.close()
udp_socket = None
dtls_connections : Dict[tuple, dict] = {}
def run_udp_loop() -> None:
while running:
try:
data, addr = udp_socket.recvfrom(2048)
if not data:
continue
first_byte = data[0]
if first_byte == 0 or first_byte == 1:
handle_stun(data, addr)
if first_byte > 19 and first_byte < 64:
handle_dtls(data, addr)
if first_byte > 127 and first_byte < 192:
handle_srtp(data, addr)
except socket.timeout:
continue
except Exception:
if running:
continue
def handle_dtls(data : bytes, addr : tuple) -> None:
conn = dtls_connections.get(addr)
if not conn:
ctx = create_ssl_context()
ssl_conn = SSL.Connection(ctx)
ssl_conn.set_accept_state()
conn = {'ssl': ssl_conn, 'encrypted': False, 'rx_srtp': None, 'tx_srtp': None}
dtls_connections[addr] = conn
ssl_conn = conn.get('ssl')
ssl_conn.bio_write(data)
try:
if not conn.get('encrypted'):
try:
ssl_conn.do_handshake()
conn['encrypted'] = True
setup_srtp_session(conn)
logger.info('dtls handshake complete with ' + str(addr), __name__)
except SSL.WantReadError:
pass
else:
try:
ssl_conn.recv(1500)
except SSL.ZeroReturnError:
pass
except SSL.Error:
pass
except Exception:
pass
flush_dtls(ssl_conn, addr)
def flush_dtls(ssl_conn : SSL.Connection, addr : tuple) -> None:
try:
outdata = ssl_conn.bio_read(1500)
if outdata:
udp_socket.sendto(outdata, addr)
except SSL.Error:
pass
def setup_srtp_session(conn : dict) -> None:
ssl_conn = conn.get('ssl')
ssl_conn.get_selected_srtp_profile()
key_len = 16
salt_len = 14
view = ssl_conn.export_keying_material(b'EXTRACTOR-dtls_srtp', 2 * (key_len + salt_len))
server_key = view[key_len:2 * key_len] + view[2 * key_len + salt_len:]
client_key = view[:key_len] + view[2 * key_len:2 * key_len + salt_len]
rx_policy = SrtpPolicy(key = client_key, ssrc_type = SrtpPolicy.SSRC_ANY_INBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
rx_policy.allow_repeat_tx = True
rx_policy.window_size = 1024
conn['rx_srtp'] = SrtpSession(rx_policy)
tx_policy = SrtpPolicy(key = server_key, ssrc_type = SrtpPolicy.SSRC_ANY_OUTBOUND, srtp_profile = SrtpPolicy.SRTP_PROFILE_AES128_CM_SHA1_80)
tx_policy.allow_repeat_tx = True
tx_policy.window_size = 1024
conn['tx_srtp'] = SrtpSession(tx_policy)
def is_rtcp(data : bytes) -> bool:
if len(data) < 2:
return False
pt = data[1] & 0x7F
return 64 <= pt <= 95
def handle_srtp(data : bytes, addr : tuple) -> None:
conn = dtls_connections.get(addr)
if not conn or not conn.get('rx_srtp'):
return
try:
if is_rtcp(data):
plain = conn.get('rx_srtp').unprotect_rtcp(data)
else:
plain = conn.get('rx_srtp').unprotect(data)
forward_rtp(plain, addr)
except Exception:
pass
def forward_rtp(data : bytes, source_addr : tuple) -> None:
for other_addr, conn in dtls_connections.items():
if other_addr == source_addr:
continue
if not conn.get('tx_srtp'):
continue
try:
if is_rtcp(data):
encrypted = conn.get('tx_srtp').protect_rtcp(data)
else:
encrypted = conn.get('tx_srtp').protect(data)
udp_socket.sendto(encrypted, other_addr)
except Exception:
pass
def handle_stun(data : bytes, addr : tuple) -> None:
if len(data) < 20:
return
msg_type = struct.unpack('!H', data[0:2])[0]
if msg_type != 0x0001:
return
msg_length = struct.unpack('!H', data[2:4])[0]
transaction_id = data[8:20]
username = None
offset = 20
while offset < 20 + msg_length:
if offset + 4 > len(data):
break
attr_type = struct.unpack('!H', data[offset:offset + 2])[0]
attr_length = struct.unpack('!H', data[offset + 2:offset + 4])[0]
attr_value = data[offset + 4:offset + 4 + attr_length]
if attr_type == 0x0006:
username = attr_value.decode('utf-8', errors = 'ignore')
padded = attr_length + (4 - attr_length % 4) % 4
offset += 4 + padded
if not username:
return
local_ufrag = username.split(':')[0] if ':' in username else username
session_pwd = None
for session in sessions.values():
if session.get('ice_ufrag') == local_ufrag:
session_pwd = session.get('ice_pwd')
break
for viewer in session.get('viewers', []):
if viewer.get('ice_ufrag') == local_ufrag:
session_pwd = viewer.get('ice_pwd')
break
if session_pwd:
break
if not session_pwd:
return
response = build_stun_response(transaction_id, addr, session_pwd)
udp_socket.sendto(response, addr)
def build_stun_response(transaction_id : bytes, addr : tuple, password : str) -> bytes:
import hmac
import zlib
magic_cookie = 0x2112A442
magic_bytes = struct.pack('!I', magic_cookie)
xport = addr[1] ^ (magic_cookie >> 16)
ip_int = struct.unpack('!I', socket.inet_aton(addr[0]))[0]
xip = struct.pack('!I', ip_int ^ magic_cookie)
xor_addr_value = struct.pack('!BBH', 0, 0x01, xport) + xip
xor_addr_attr = struct.pack('!HH', 0x0020, len(xor_addr_value)) + xor_addr_value
attrs_before_integrity = xor_addr_attr
integrity_dummy_len = len(attrs_before_integrity) + 4 + 20
header_for_hmac = struct.pack('!HH', 0x0101, integrity_dummy_len) + magic_bytes + transaction_id
key = password.encode('utf-8')
integrity = hmac.new(key, header_for_hmac + attrs_before_integrity, hashlib.sha1).digest()
integrity_attr = struct.pack('!HH', 0x0008, 20) + integrity
attrs_before_fp = attrs_before_integrity + integrity_attr
fp_dummy_len = len(attrs_before_fp) + 4 + 4
header_for_fp = struct.pack('!HH', 0x0101, fp_dummy_len) + magic_bytes + transaction_id
crc = zlib.crc32(header_for_fp + attrs_before_fp) ^ 0x5354554E
fingerprint_attr = struct.pack('!HHI', 0x8028, 4, crc & 0xFFFFFFFF)
all_attrs = attrs_before_integrity + integrity_attr + fingerprint_attr
header = struct.pack('!HH', 0x0101, len(all_attrs)) + magic_bytes + transaction_id
return header + all_attrs
def run_http_server() -> None:
from http.server import HTTPServer, BaseHTTPRequestHandler
class WhipWhepHandler(BaseHTTPRequestHandler):
def log_message(self, format, *args) -> None:
pass
def do_POST(self) -> None:
path = self.path
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length).decode('utf-8') if content_length else ''
if path.endswith('/whip'):
stream_path = path[1:].rsplit('/whip', 1)[0]
answer = handle_whip(stream_path, body)
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Location', path)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
return
self.send_response(404)
self.end_headers()
return
if path.endswith('/whep'):
stream_path = path[1:].rsplit('/whep', 1)[0]
answer = handle_whep(stream_path, body)
if answer:
self.send_response(201)
self.send_header('Content-Type', 'application/sdp')
self.send_header('Location', path)
self.send_header('Access-Control-Allow-Origin', '*')
self.end_headers()
self.wfile.write(answer.encode('utf-8'))
return
self.send_response(404)
self.end_headers()
return
self.send_response(404)
self.end_headers()
def do_OPTIONS(self) -> None:
self.send_response(200)
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'POST, DELETE, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
self.end_headers()
server = HTTPServer(('0.0.0.0', WHIP_PORT), WhipWhepHandler)
server.timeout = 1
while running:
server.handle_request()
+99
View File
@@ -0,0 +1,99 @@
import os
import shutil
import subprocess
import time
from typing import Optional
import httpx
from facefusion import logger
RELAY_PORT : int = 8891
RELAY_BINARY : str = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'tools', 'whip_relay')
RELAY_PROCESS : Optional[subprocess.Popen[bytes]] = None
def get_whip_url(stream_path : str) -> str:
return 'http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/whip'
def get_whep_url(stream_path : str) -> str:
return 'http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/whep'
def resolve_binary() -> str:
relay_path = shutil.which('whip_relay')
if relay_path:
return relay_path
if os.path.isfile(RELAY_BINARY):
return RELAY_BINARY
return RELAY_BINARY
def start() -> None:
global RELAY_PROCESS
subprocess.run([ 'fuser', '-k', str(RELAY_PORT) + '/tcp' ], stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL)
time.sleep(0.5)
relay_binary = resolve_binary()
if not os.path.isfile(relay_binary):
logger.warn('whip_relay binary not found at ' + relay_binary + ', skipping', __name__)
return
env = os.environ.copy()
env['LD_LIBRARY_PATH'] = '/home/henry/local/lib:' + env.get('LD_LIBRARY_PATH', '')
RELAY_PROCESS = subprocess.Popen(
[ relay_binary, str(RELAY_PORT) ],
env = env,
stdout = subprocess.PIPE,
stderr = subprocess.PIPE
)
logger.info('whip relay started on port ' + str(RELAY_PORT), __name__)
def stop() -> None:
global RELAY_PROCESS
if RELAY_PROCESS:
RELAY_PROCESS.terminate()
RELAY_PROCESS.wait()
RELAY_PROCESS = None
def wait_for_ready() -> bool:
for _ in range(10):
try:
response = httpx.get('http://localhost:' + str(RELAY_PORT) + '/health', timeout = 1)
if response.status_code == 200:
return True
except Exception:
pass
time.sleep(0.5)
return False
def is_session_ready(stream_path : str) -> bool:
try:
response = httpx.get('http://localhost:' + str(RELAY_PORT) + '/session/' + stream_path, timeout = 1)
if response.status_code == 200:
return True
except Exception:
pass
return False
def create_session(stream_path : str) -> int:
try:
response = httpx.post('http://localhost:' + str(RELAY_PORT) + '/' + stream_path + '/create', timeout = 5)
if response.status_code == 200:
return int(response.text)
except Exception:
pass
return 0
+3 -3
View File
@@ -72,11 +72,11 @@
.timeline { display: none; align-items: stretch; padding: 0; background: #0e0e14; border-top: 1px solid #1e1e2e; border-bottom: 1px solid #1e1e2e; flex-shrink: 0; }
.timeline.visible { display: flex; }
.timeline .transport { display: flex; align-items: center; gap: 2px; padding: 0 0.4rem; background: #12121a; border-right: 1px solid #1e1e2e; }
.timeline .transport-btn { width: 28px; height: 28px; border: none; border-radius: 6px; cursor: pointer; display: flex; align-items: center; justify-content: center; background: transparent; color: #888; transition: all 0.15s; }
.timeline .transport-btn { width: 36px; height: 36px; border: none; border-radius: 8px; cursor: pointer; display: flex; align-items: center; justify-content: center; background: transparent; color: #888; transition: all 0.15s; }
.timeline .transport-btn:hover { background: #1e1e2e; color: #fff; }
.timeline .transport-btn:disabled { opacity: 0.25; cursor: not-allowed; }
.timeline .transport-btn.active { color: #00b894; }
.timeline .transport-btn svg { width: 14px; height: 14px; fill: currentColor; }
.timeline .transport-btn.active { color: #888; }
.timeline .transport-btn svg { width: 18px; height: 18px; fill: currentColor; }
.timeline .time { font-size: 0.75rem; color: #888; font-family: monospace; min-width: 60px; display: flex; align-items: center; justify-content: center; padding: 0 0.6rem; background: #12121a; border-right: 1px solid #1e1e2e; }
.timeline .time:last-child { border-right: none; border-left: 1px solid #1e1e2e; }
.timeline .track { flex: 1; position: relative; height: 2em; cursor: pointer; background: repeating-linear-gradient(90deg, transparent, transparent 59px, #1a1a25 59px, #1a1a25 60px); }