diff --git a/facefusion/apis/endpoints/stream.py b/facefusion/apis/endpoints/stream.py index 446d85fe..667cef50 100644 --- a/facefusion/apis/endpoints/stream.py +++ b/facefusion/apis/endpoints/stream.py @@ -5,11 +5,11 @@ from starlette.websockets import WebSocket from facefusion import rtc_store, session_context, session_manager from facefusion.apis.session_helper import extract_access_token -from facefusion.apis.stream_helper import detect_websocket_stream_mode, handle_image_stream, handle_video_stream +from facefusion.apis.stream_helper import handle_image_stream, handle_video_stream async def websocket_stream(websocket : WebSocket) -> None: - stream_mode = detect_websocket_stream_mode(websocket.scope) + stream_mode = websocket.query_params.get('mode') if stream_mode == 'image': await handle_image_stream(websocket) diff --git a/facefusion/apis/stream_helper.py b/facefusion/apis/stream_helper.py index 5f0bdb40..a497700e 100644 --- a/facefusion/apis/stream_helper.py +++ b/facefusion/apis/stream_helper.py @@ -4,11 +4,10 @@ import os import subprocess from collections import deque from collections.abc import AsyncIterator -from typing import Optional, cast +from typing import Optional import cv2 import numpy -from starlette.types import Scope from starlette.websockets import WebSocket, WebSocketState from facefusion import rtc_store, session_context, session_manager, state_manager @@ -17,7 +16,7 @@ from facefusion.apis.session_helper import extract_access_token from facefusion.common_helper import is_linux, is_macos from facefusion.ffmpeg import spawn_stream from facefusion.streamer import process_vision_frame -from facefusion.types import Resolution, SessionId, VisionFrame, WebSocketStreamMode +from facefusion.types import Resolution, SessionId, VisionFrame def calculate_bitrate(resolution : Resolution) -> int: # TODO : improve the bitrate calculation @@ -30,19 +29,6 @@ def calculate_buffer_size(resolution : Resolution) -> int: return calculate_bitrate(resolution) * 2 -def detect_websocket_stream_mode(scope : Scope) -> Optional[WebSocketStreamMode]: - subprotocol = get_sec_websocket_protocol(scope) - - if subprotocol: - for protocol in subprotocol.split(','): - websocket_stream_mode = cast(WebSocketStreamMode, protocol.strip()) - - if websocket_stream_mode in [ 'image', 'video' ]: - return websocket_stream_mode - - return None - - def read_pipe_buffer(pipe_handle : int, size : int) -> Optional[bytes]: byte_buffer = bytearray() frame_data = os.read(pipe_handle, size - len(byte_buffer)) diff --git a/tests/test_api_helper.py b/tests/test_api_helper.py new file mode 100644 index 00000000..f9ae5e4c --- /dev/null +++ b/tests/test_api_helper.py @@ -0,0 +1,11 @@ +from facefusion.apis.api_helper import get_sec_websocket_protocol + + +def test_get_sec_websocket_protocol() -> None: + scope =\ + { + 'type': 'websocket', + 'headers': [ (b'sec-websocket-protocol', b'access_token.abc') ] + } + + assert get_sec_websocket_protocol(scope) == 'access_token.abc' diff --git a/tests/test_stream_helper.py b/tests/test_stream_helper.py index 95d09714..31b86506 100644 --- a/tests/test_stream_helper.py +++ b/tests/test_stream_helper.py @@ -1,6 +1,6 @@ import os -from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, detect_websocket_stream_mode, read_pipe_buffer +from facefusion.apis.stream_helper import calculate_bitrate, calculate_buffer_size, read_pipe_buffer def test_calculate_bitrate() -> None: @@ -19,24 +19,6 @@ def test_calculate_buffer_size() -> None: assert calculate_buffer_size((3840, 2160)) == 14000 -def test_detect_websocket_stream_mode() -> None: - scope =\ - { - 'type': 'websocket', - 'headers': [ (b'sec-websocket-protocol', b'image') ] - } - - assert detect_websocket_stream_mode(scope) == 'image' - - scope =\ - { - 'type': 'websocket', - 'headers': [ (b'sec-websocket-protocol', b'video') ] - } - - assert detect_websocket_stream_mode(scope) == 'video' - - def test_read_pipe_buffer() -> None: read_pipe, write_pipe = os.pipe() os.write(write_pipe, b'123456')