add metrics

This commit is contained in:
harisreedhar
2026-02-07 14:13:45 +05:30
committed by henryruhs
parent 2db7ad426a
commit db5ffdc449
5 changed files with 103 additions and 0 deletions
+3
View File
@@ -7,6 +7,7 @@ from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_asset
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
from facefusion.apis.metrics import get_metrics, websocket_metrics
def create_api() -> Starlette:
@@ -23,6 +24,8 @@ def create_api() -> Starlette:
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]),
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
]
+32
View File
@@ -0,0 +1,32 @@
import asyncio
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_404_NOT_FOUND
from starlette.websockets import WebSocket
from facefusion.apis.api_helper import get_sec_websocket_protocol
from facefusion.system import get_metrics_set
async def get_metrics(request : Request) -> Response:
metrics_set = get_metrics_set()
if metrics_set:
return JSONResponse(metrics_set)
return Response(status_code = HTTP_404_NOT_FOUND)
async def websocket_metrics(websocket : WebSocket) -> None:
subprotocol = get_sec_websocket_protocol(websocket.scope)
await websocket.accept(subprotocol = subprotocol)
try:
while True:
metrics_set = get_metrics_set()
await websocket.send_json(metrics_set)
await asyncio.sleep(2)
except Exception:
pass
+9
View File
@@ -0,0 +1,9 @@
from facefusion.execution import detect_execution_devices
from facefusion.types import Metrics
def get_metrics_set() -> Metrics:
return\
{
'execution_devices': detect_execution_devices()
}
+4
View File
@@ -298,6 +298,10 @@ ExecutionDevice = TypedDict('ExecutionDevice',
'temperature' : ExecutionDeviceTemperature,
'utilization' : ExecutionDeviceUtilization
})
Metrics = TypedDict('Metrics',
{
'execution_devices' : List[ExecutionDevice]
})
DownloadProvider = Literal['github', 'huggingface']
DownloadProviderValue = TypedDict('DownloadProviderValue',
+55
View File
@@ -0,0 +1,55 @@
from typing import Iterator
import pytest
from starlette.testclient import TestClient
from facefusion import metadata, session_manager
from facefusion.apis.core import create_api
@pytest.fixture(scope = 'module')
def test_client() -> Iterator[TestClient]:
with TestClient(create_api()) as test_client:
yield test_client
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
def test_get_metrics(test_client : TestClient) -> None:
get_metrics_response = test_client.get('/metrics')
assert get_metrics_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
get_metrics_response = test_client.get('/metrics', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
get_metrics_body = get_metrics_response.json()
assert get_metrics_body.get('execution_devices')
assert get_metrics_response.status_code == 200
def test_websocket_metrics(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with test_client.websocket_connect('/metrics', subprotocols =
[
'access_token.' + create_session_body.get('access_token')
]) as websocket:
metrics_set = websocket.receive_json()
assert metrics_set.get('execution_devices')