mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 13:05:59 +02:00
feat/ping-endpoint (#1001)
* api: add WebSocket /ping endpoint and update session guard to support WebSocket subprotocol auth; add tests (test_api_ping.py) * Initial websocket support using ping * Initial websocket support using ping * Initial websocket support using ping * Combine imports
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
|
||||
protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol')
|
||||
|
||||
if protocol_header:
|
||||
protocol, _, _ = protocol_header.partition(',')
|
||||
return protocol.strip()
|
||||
|
||||
return None
|
||||
|
||||
+13
-16
@@ -1,28 +1,25 @@
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Route
|
||||
from starlette.routing import Route, WebSocketRoute
|
||||
|
||||
from facefusion.apis.session import create_session
|
||||
from facefusion.apis.session import create_session_guard
|
||||
from facefusion.apis.session import destroy_session
|
||||
from facefusion.apis.session import get_session
|
||||
from facefusion.apis.session import refresh_session
|
||||
from facefusion.apis.state import get_state
|
||||
from facefusion.apis.state import set_state
|
||||
from facefusion.apis.ping import websocket_ping
|
||||
from facefusion.apis.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.state import get_state, set_state
|
||||
|
||||
|
||||
def create_api() -> Starlette:
|
||||
session_guard = Middleware(create_session_guard)
|
||||
routes =\
|
||||
[
|
||||
Route('/session', create_session, methods = [ 'POST' ]),
|
||||
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/session', refresh_session, methods = [ 'PUT' ]),
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ])
|
||||
]
|
||||
[
|
||||
Route('/session', create_session, methods = [ 'POST' ]),
|
||||
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/session', refresh_session, methods = [ 'PUT' ]),
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes)
|
||||
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
|
||||
|
||||
async def websocket_ping(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive()
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
@@ -9,6 +9,7 @@ from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZE
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from facefusion import session_context, session_manager, translator
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.types import Token
|
||||
|
||||
|
||||
@@ -34,7 +35,7 @@ async def create_session(request : Request) -> JSONResponse:
|
||||
|
||||
|
||||
async def get_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
access_token = extract_access_token(request.scope)
|
||||
|
||||
if access_token:
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
@@ -77,7 +78,7 @@ async def refresh_session(request : Request) -> JSONResponse:
|
||||
|
||||
|
||||
async def destroy_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
access_token = extract_access_token(request.scope)
|
||||
|
||||
if access_token:
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
@@ -98,7 +99,7 @@ async def destroy_session(request : Request) -> JSONResponse:
|
||||
|
||||
def create_session_guard(app : ASGIApp) -> ASGIApp:
|
||||
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
|
||||
access_token = extract_access_token(Headers(scope = scope))
|
||||
access_token = extract_access_token(scope)
|
||||
|
||||
if access_token:
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
@@ -124,13 +125,23 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
|
||||
return middleware
|
||||
|
||||
|
||||
def extract_access_token(headers : Headers) -> Optional[Token]:
|
||||
auth_header = headers.get('Authorization')
|
||||
def extract_access_token(scope : Scope) -> Optional[Token]:
|
||||
if scope.get('type') == 'http':
|
||||
auth_header = Headers(scope = scope).get('Authorization')
|
||||
|
||||
if auth_header:
|
||||
auth_prefix, _, access_token = auth_header.partition(' ')
|
||||
if auth_header:
|
||||
auth_prefix, _, access_token = auth_header.partition(' ')
|
||||
|
||||
if auth_prefix.lower() == 'bearer' and access_token:
|
||||
return access_token
|
||||
if auth_prefix.lower() == 'bearer' and access_token:
|
||||
return access_token
|
||||
|
||||
if scope.get('type') == 'websocket':
|
||||
subprotocol = get_sec_websocket_protocol(scope)
|
||||
|
||||
if subprotocol:
|
||||
protocol_prefix, _, access_token = subprotocol.partition('.')
|
||||
|
||||
if protocol_prefix == 'access_token' and access_token:
|
||||
return access_token
|
||||
|
||||
return None
|
||||
|
||||
@@ -7,3 +7,4 @@ tqdm==4.67.3
|
||||
scipy==1.17.1
|
||||
starlette==0.50.0
|
||||
uvicorn==0.34.0
|
||||
websockets==15.0.1
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, session_manager
|
||||
from facefusion.apis.core import create_api
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
def test_client() -> Iterator[TestClient]:
|
||||
with TestClient(create_api()) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
|
||||
|
||||
def test_ping(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with test_client.websocket_connect('/ping', subprotocols =
|
||||
[
|
||||
'access_token.' + create_session_body.get('access_token')
|
||||
]) as websocket:
|
||||
assert websocket
|
||||
Reference in New Issue
Block a user