feat/ping-endpoint (#1001)

* api: add WebSocket /ping endpoint and update session guard to support WebSocket subprotocol auth; add tests (test_api_ping.py)

* Initial websocket support using ping

* Initial websocket support using ping

* Initial websocket support using ping

* Combine imports
This commit is contained in:
Henry Ruhs
2025-12-11 12:00:59 +01:00
committed by henryruhs
parent 127958d581
commit 3d2b0c222c
6 changed files with 97 additions and 25 deletions
+15
View File
@@ -0,0 +1,15 @@
from typing import Optional
from starlette.datastructures import Headers
from starlette.types import Scope
def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol')
if protocol_header:
protocol, _, _ = protocol_header.partition(',')
return protocol.strip()
return None
+13 -16
View File
@@ -1,28 +1,25 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route
from starlette.routing import Route, WebSocketRoute
from facefusion.apis.session import create_session
from facefusion.apis.session import create_session_guard
from facefusion.apis.session import destroy_session
from facefusion.apis.session import get_session
from facefusion.apis.session import refresh_session
from facefusion.apis.state import get_state
from facefusion.apis.state import set_state
from facefusion.apis.ping import websocket_ping
from facefusion.apis.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
from facefusion.apis.state import get_state, set_state
def create_api() -> Starlette:
session_guard = Middleware(create_session_guard)
routes =\
[
Route('/session', create_session, methods = [ 'POST' ]),
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/session', refresh_session, methods = [ 'PUT' ]),
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ])
]
[
Route('/session', create_session, methods = [ 'POST' ]),
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/session', refresh_session, methods = [ 'PUT' ]),
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
]
api = Starlette(routes = routes)
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
+16
View File
@@ -0,0 +1,16 @@
from starlette.websockets import WebSocket
from facefusion.apis.api_helper import get_sec_websocket_protocol
async def websocket_ping(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
await websocket.accept(subprotocol = subprotocol)
try:
while True:
await websocket.receive()
except Exception:
pass
+20 -9
View File
@@ -9,6 +9,7 @@ from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZE
from starlette.types import ASGIApp, Receive, Scope, Send
from facefusion import session_context, session_manager, translator
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.types import Token
@@ -34,7 +35,7 @@ async def create_session(request : Request) -> JSONResponse:
async def get_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
access_token = extract_access_token(request.scope)
if access_token:
session_id = session_manager.find_session_id(access_token)
@@ -77,7 +78,7 @@ async def refresh_session(request : Request) -> JSONResponse:
async def destroy_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
access_token = extract_access_token(request.scope)
if access_token:
session_id = session_manager.find_session_id(access_token)
@@ -98,7 +99,7 @@ async def destroy_session(request : Request) -> JSONResponse:
def create_session_guard(app : ASGIApp) -> ASGIApp:
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
access_token = extract_access_token(Headers(scope = scope))
access_token = extract_access_token(scope)
if access_token:
session_id = session_manager.find_session_id(access_token)
@@ -124,13 +125,23 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
return middleware
def extract_access_token(headers : Headers) -> Optional[Token]:
auth_header = headers.get('Authorization')
def extract_access_token(scope : Scope) -> Optional[Token]:
if scope.get('type') == 'http':
auth_header = Headers(scope = scope).get('Authorization')
if auth_header:
auth_prefix, _, access_token = auth_header.partition(' ')
if auth_header:
auth_prefix, _, access_token = auth_header.partition(' ')
if auth_prefix.lower() == 'bearer' and access_token:
return access_token
if auth_prefix.lower() == 'bearer' and access_token:
return access_token
if scope.get('type') == 'websocket':
subprotocol = get_sec_websocket_protocol(scope)
if subprotocol:
protocol_prefix, _, access_token = subprotocol.partition('.')
if protocol_prefix == 'access_token' and access_token:
return access_token
return None
+1
View File
@@ -7,3 +7,4 @@ tqdm==4.67.3
scipy==1.17.1
starlette==0.50.0
uvicorn==0.34.0
websockets==15.0.1
+32
View File
@@ -0,0 +1,32 @@
from typing import Iterator
import pytest
from starlette.testclient import TestClient
from facefusion import metadata, session_manager
from facefusion.apis.core import create_api
@pytest.fixture(scope = 'module')
def test_client() -> Iterator[TestClient]:
with TestClient(create_api()) as test_client:
yield test_client
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
def test_ping(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with test_client.websocket_connect('/ping', subprotocols =
[
'access_token.' + create_session_body.get('access_token')
]) as websocket:
assert websocket