mirror of
https://github.com/facefusion/facefusion.git
synced 2026-05-12 18:32:18 +02:00
add metrics
This commit is contained in:
@@ -7,6 +7,7 @@ from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_asset
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.metrics import get_metrics, websocket_metrics
|
||||
|
||||
|
||||
def create_api() -> Starlette:
|
||||
@@ -23,6 +24,8 @@ def create_api() -> Starlette:
|
||||
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]),
|
||||
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/metrics', get_metrics, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
WebSocketRoute('/metrics', websocket_metrics, middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import asyncio
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from facefusion.apis.api_helper import get_sec_websocket_protocol
|
||||
from facefusion.system import get_metrics_set
|
||||
|
||||
|
||||
async def get_metrics(request : Request) -> Response:
|
||||
metrics_set = get_metrics_set()
|
||||
|
||||
if metrics_set:
|
||||
return JSONResponse(metrics_set)
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
async def websocket_metrics(websocket : WebSocket) -> None:
|
||||
subprotocol = get_sec_websocket_protocol(websocket.scope)
|
||||
await websocket.accept(subprotocol = subprotocol)
|
||||
|
||||
try:
|
||||
while True:
|
||||
metrics_set = get_metrics_set()
|
||||
await websocket.send_json(metrics_set)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from facefusion.execution import detect_execution_devices
|
||||
from facefusion.types import Metrics
|
||||
|
||||
|
||||
def get_metrics_set() -> Metrics:
|
||||
return\
|
||||
{
|
||||
'execution_devices': detect_execution_devices()
|
||||
}
|
||||
@@ -298,6 +298,10 @@ ExecutionDevice = TypedDict('ExecutionDevice',
|
||||
'temperature' : ExecutionDeviceTemperature,
|
||||
'utilization' : ExecutionDeviceUtilization
|
||||
})
|
||||
Metrics = TypedDict('Metrics',
|
||||
{
|
||||
'execution_devices' : List[ExecutionDevice]
|
||||
})
|
||||
|
||||
DownloadProvider = Literal['github', 'huggingface']
|
||||
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, session_manager
|
||||
from facefusion.apis.core import create_api
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
def test_client() -> Iterator[TestClient]:
|
||||
with TestClient(create_api()) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
|
||||
|
||||
def test_get_metrics(test_client : TestClient) -> None:
|
||||
get_metrics_response = test_client.get('/metrics')
|
||||
|
||||
assert get_metrics_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
get_metrics_response = test_client.get('/metrics', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
get_metrics_body = get_metrics_response.json()
|
||||
|
||||
assert get_metrics_body.get('execution_devices')
|
||||
assert get_metrics_response.status_code == 200
|
||||
|
||||
|
||||
def test_websocket_metrics(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with test_client.websocket_connect('/metrics', subprotocols =
|
||||
[
|
||||
'access_token.' + create_session_body.get('access_token')
|
||||
]) as websocket:
|
||||
metrics_set = websocket.receive_json()
|
||||
|
||||
assert metrics_set.get('execution_devices')
|
||||
Reference in New Issue
Block a user