diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 9ecd5744..160ecefb 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -7,6 +7,7 @@ from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_asset from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state +from facefusion.apis.metrics import get_metrics, websocket_metrics def create_api() -> Starlette: @@ -23,6 +24,8 @@ def create_api() -> Starlette: Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]), Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]), Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]), + Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]), + WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]) ] diff --git a/facefusion/apis/metrics.py b/facefusion/apis/metrics.py new file mode 100644 index 00000000..6d8c5e8a --- /dev/null +++ b/facefusion/apis/metrics.py @@ -0,0 +1,32 @@ +import asyncio + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.status import HTTP_404_NOT_FOUND +from starlette.websockets import WebSocket + +from facefusion.apis.api_helper import get_sec_websocket_protocol +from facefusion.system import get_metrics_set + + +async def get_metrics(request : Request) -> Response: + metrics_set = get_metrics_set() + + if metrics_set: + return JSONResponse(metrics_set) + + return Response(status_code = HTTP_404_NOT_FOUND) + + +async def websocket_metrics(websocket : WebSocket) -> None: + subprotocol = get_sec_websocket_protocol(websocket.scope) + await websocket.accept(subprotocol = subprotocol) + + try: + while True: + metrics_set = get_metrics_set() + await websocket.send_json(metrics_set) + await asyncio.sleep(2) + + except Exception: + pass diff --git a/facefusion/system.py b/facefusion/system.py new file mode 100644 index 00000000..89286f1d --- /dev/null +++ b/facefusion/system.py @@ -0,0 +1,9 @@ +from facefusion.execution import detect_execution_devices +from facefusion.types import Metrics + + +def get_metrics_set() -> Metrics: + return\ + { + 'execution_devices': detect_execution_devices() + } diff --git a/facefusion/types.py b/facefusion/types.py index 3eb22a2f..0576f443 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -298,6 +298,10 @@ ExecutionDevice = TypedDict('ExecutionDevice', 'temperature' : ExecutionDeviceTemperature, 'utilization' : ExecutionDeviceUtilization }) +Metrics = TypedDict('Metrics', +{ + 'execution_devices' : List[ExecutionDevice] +}) DownloadProvider = Literal['github', 'huggingface'] DownloadProviderValue = TypedDict('DownloadProviderValue', diff --git a/tests/test_api_metrics.py b/tests/test_api_metrics.py new file mode 100644 index 00000000..1083dc10 --- /dev/null +++ b/tests/test_api_metrics.py @@ -0,0 +1,55 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager +from facefusion.apis.core import create_api + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + + +def test_get_metrics(test_client : TestClient) -> None: + get_metrics_response = test_client.get('/metrics') + + assert get_metrics_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + get_metrics_response = test_client.get('/metrics', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + get_metrics_body = get_metrics_response.json() + + assert get_metrics_body.get('execution_devices') + assert get_metrics_response.status_code == 200 + + +def test_websocket_metrics(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + with test_client.websocket_connect('/metrics', subprotocols = + [ + 'access_token.' + create_session_body.get('access_token') + ]) as websocket: + metrics_set = websocket.receive_json() + + assert metrics_set.get('execution_devices') \ No newline at end of file