Local API (#982)

* Introduce API scelleton

* Raw impl for session

* Simple state endpoint

* Apply _body naming

* Finalize session testing and comment out tons of useless code

* Clean and refactor part1

* Clean and refactor part2

* Clean and refactor part2

* Clean and refactor part2

* Clean and refactor part2

* Refactor middleware

* Refactor middleware

* Clean and refactor part3

* TDD and 2 beers

* TDD and 2 beers

* Complete state endpoints

* You can only set what is already present

* Use only JSON as response

* Use default logger

* Improve auth extraction

* Extend api command with more args

* Adjust API messages
This commit is contained in:
Henry Ruhs
2025-11-17 08:36:31 +01:00
committed by henryruhs
parent f745ddc831
commit 3daf049684
20 changed files with 559 additions and 4 deletions
+2
View File
@@ -35,6 +35,7 @@ jobs:
python-version: '3.12'
- run: python install.py --onnxruntime default --skip-conda
- run: pip install pytest
- run: pip install httpx
- run: pytest
report:
needs: test
@@ -52,6 +53,7 @@ jobs:
- run: pip install coveralls
- run: pip install pytest
- run: pip install pytest-cov
- run: pip install httpx
- run: pytest tests --cov facefusion
- run: coveralls --service github
env:
+1
View File
@@ -31,6 +31,7 @@ commands:
batch-run run the program in batch mode
force-download force automate downloads and exit
benchmark benchmark the program
api start the API server
job-list list jobs by status
job-create create a drafted job
job-submit submit a drafted job to become a queued job
+4
View File
@@ -112,6 +112,10 @@ benchmark_mode =
benchmark_resolutions =
benchmark_cycle_count =
[api]
api_host =
api_port =
[execution]
execution_device_ids =
execution_providers =
View File
+30
View File
@@ -0,0 +1,30 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route
from facefusion.apis.session import create_session
from facefusion.apis.session import create_session_guard
from facefusion.apis.session import destroy_session
from facefusion.apis.session import get_session
from facefusion.apis.session import refresh_session
from facefusion.apis.state import get_state
from facefusion.apis.state import set_state
def create_api() -> Starlette:
session_guard = Middleware(create_session_guard)
routes =\
[
Route('/session', create_session, methods = [ 'POST' ]),
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/session', refresh_session, methods = [ 'PUT' ]),
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ])
]
api = Starlette(routes = routes)
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
return api
+12
View File
@@ -0,0 +1,12 @@
from facefusion.types import Locals
LOCALS : Locals =\
{
'en':
{
'ok': 'ok',
'something_went_wrong': 'something went wrong',
'invalid_access_token': 'invalid access token',
'invalid_refresh_token': 'invalid refresh token'
}
}
+131
View File
@@ -0,0 +1,131 @@
import os
from typing import Optional
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, HTTP_426_UPGRADE_REQUIRED
from starlette.types import ASGIApp, Receive, Scope, Send
from facefusion import session_manager, translator
from facefusion.types import Token
async def create_session(request : Request) -> JSONResponse:
body = await request.json()
if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'):
session = session_manager.create_session()
session_manager.set_session(session.get('access_token'), session)
return JSONResponse(
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token')
}, status_code = HTTP_201_CREATED)
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
}, status_code = HTTP_401_UNAUTHORIZED)
async def get_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
if access_token:
session = session_manager.get_session(access_token)
if session:
return JSONResponse(
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token'),
'created_at': session.get('created_at').isoformat(),
'expires_at': session.get('expires_at').isoformat()
}, status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
}, status_code = HTTP_401_UNAUTHORIZED)
async def refresh_session(request : Request) -> JSONResponse:
body = await request.json()
for access_token, session in session_manager.SESSIONS.items():
if session.get('refresh_token') == body.get('refresh_token'):
session_manager.clear_session(access_token)
session = session_manager.create_session()
session_manager.set_session(session.get('access_token'), session)
return JSONResponse(
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token')
}, status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
}, status_code = HTTP_401_UNAUTHORIZED)
async def destroy_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
if access_token:
session = session_manager.get_session(access_token)
if session:
session_manager.clear_session(access_token)
return JSONResponse(
{
'message': translator.get('ok', __package__)
}, status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
}, status_code = HTTP_401_UNAUTHORIZED)
def create_session_guard(app : ASGIApp) -> ASGIApp:
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
access_token = extract_access_token(Headers(scope = scope))
if access_token and session_manager.validate_session(access_token):
return await app(scope, receive, send)
if access_token:
session = session_manager.get_session(access_token)
if session:
response = JSONResponse(
{
'message': translator.get('invalid_access_token', __package__)
}, status_code = HTTP_426_UPGRADE_REQUIRED)
return await response(scope, receive, send)
response = JSONResponse(
{
'message': translator.get('invalid_access_token', __package__)
}, status_code = HTTP_401_UNAUTHORIZED)
return await response(scope, receive, send)
return middleware
def extract_access_token(headers : Headers) -> Optional[Token]:
auth_header = headers.get('Authorization')
if auth_header:
auth_prefix, _, access_token = auth_header.partition(' ')
if auth_prefix.lower() == 'bearer' and access_token:
return access_token
return None
+23
View File
@@ -0,0 +1,23 @@
from typing import get_args
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK
from facefusion import state_manager
from facefusion.types import StateKey
async def get_state(request : Request) -> JSONResponse:
return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK)
async def set_state(request : Request) -> JSONResponse:
body = await request.json()
for key, value in body.items():
if key in get_args(StateKey):
state_manager.set_item(key, value)
return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK)
+3
View File
@@ -115,6 +115,9 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
apply_state_item('benchmark_mode', args.get('benchmark_mode'))
apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions'))
apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count'))
# api
apply_state_item('api_host', args.get('api_host'))
apply_state_item('api_port', args.get('api_port'))
# memory
apply_state_item('video_memory_strategy', args.get('video_memory_strategy'))
apply_state_item('system_memory_limit', args.get('system_memory_limit'))
+8
View File
@@ -5,7 +5,10 @@ import signal
import sys
from time import time
import uvicorn
from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, state_manager, translator, voice_extractor
from facefusion.apis.core import create_api
from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args
from facefusion.download import conditional_download_hashes, conditional_download_sources
from facefusion.exit_helper import hard_exit, signal_exit
@@ -55,6 +58,11 @@ def route(args : Args) -> None:
hard_exit(2)
benchmarker.render()
if state_manager.get_item('command') == 'api':
logger.info(translator.get('api_started').format(host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')), __name__)
uvicorn.run(create_api(), host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port'))
hard_exit(1)
if state_manager.get_item('command') in [ 'job-list', 'job-create', 'job-submit', 'job-submit-all', 'job-delete', 'job-delete-all', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ]:
if not job_manager.init_jobs(state_manager.get_item('jobs_path')):
hard_exit(1)
+4 -2
View File
@@ -48,10 +48,9 @@ LOCALES : Locales =\
'no_source_face_detected': 'no source face detected',
'processor_not_loaded': 'processor {processor} could not be loaded',
'processor_not_implemented': 'processor {processor} not implemented correctly',
'ui_layout_not_loaded': 'ui layout {ui_layout} could not be loaded',
'ui_layout_not_implemented': 'ui layout {ui_layout} not implemented correctly',
'stream_not_loaded': 'stream {stream_mode} could not be loaded',
'stream_not_supported': 'stream not supported',
'api_started': 'started API on {host}:{port}',
'job_created': 'job {job_id} created',
'job_not_created': 'job {job_id} not created',
'job_submitted': 'job {job_id} submitted',
@@ -154,6 +153,8 @@ LOCALES : Locales =\
'benchmark_mode': 'choose the benchmark mode',
'benchmark_resolutions': 'choose the resolutions for the benchmarks (choices: {choices}, ...)',
'benchmark_cycle_count': 'specify the amount of cycles per benchmark',
'api_host': 'specify the API host',
'api_port': 'specify the API port',
'execution_device_ids': 'specify the devices used for processing',
'execution_providers': 'inference using different providers (choices: {choices}, ...)',
'execution_thread_count': 'specify the amount of parallel threads while processing',
@@ -165,6 +166,7 @@ LOCALES : Locales =\
'batch_run': 'run the program in batch mode',
'force_download': 'force automate downloads and exit',
'benchmark': 'benchmark the program',
'api': 'start the API server',
'job_id': 'specify the job id',
'job_status': 'specify the job status',
'step_index': 'specify the step index',
+9
View File
@@ -220,6 +220,14 @@ def create_benchmark_program() -> ArgumentParser:
return program
def create_api_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
group_api = program.add_argument_group('api')
group_api.add_argument('--api-host', help = translator.get('help.api_host'), default = config.get_str_value('api', 'api_host', '127.0.0.1'))
group_api.add_argument('--api-port', help = translator.get('help.api_port'), type = int, default = config.get_int_value('api', 'api_port', '8000'))
return program
def create_execution_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
available_execution_providers = get_available_execution_providers()
@@ -292,6 +300,7 @@ def create_program() -> ArgumentParser:
sub_program.add_parser('batch-run', help = translator.get('help.batch_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_pattern_program(), create_target_pattern_program(), create_output_pattern_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
sub_program.add_parser('force-download', help = translator.get('help.force_download'), parents = [ create_download_providers_program(), create_download_scope_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
sub_program.add_parser('benchmark', help = translator.get('help.benchmark'), parents = [ create_temp_path_program(), collect_step_program(), create_benchmark_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
sub_program.add_parser('api', help = translator.get('help.api'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_api_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
# job manager
sub_program.add_parser('job-list', help = translator.get('help.job_list'), parents = [ create_job_status_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
sub_program.add_parser('job-create', help = translator.get('help.job_create'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
+39
View File
@@ -0,0 +1,39 @@
import secrets
from datetime import datetime, timedelta
from typing import Dict
from typing import Optional
from facefusion.types import Session, Token
SESSIONS : Dict[Token, Session] = {}
def create_session() -> Session:
session : Session =\
{
'access_token': secrets.token_urlsafe(128),
'refresh_token': secrets.token_urlsafe(128),
'created_at': datetime.now(),
'expires_at': datetime.now() + timedelta(minutes = 10)
}
return session
def get_session(access_token : Token) -> Optional[Session]:
return SESSIONS.get(access_token)
def set_session(access_token : Token, session : Session) -> None:
SESSIONS[access_token] = session
def validate_session(access_token : Token) -> bool:
session = get_session(access_token)
return session and datetime.now() <= session.get('expires_at')
def clear_session(access_token : Token) -> None:
if access_token in SESSIONS:
del SESSIONS[access_token]
+10
View File
@@ -1,4 +1,5 @@
from collections import namedtuple
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeAlias, TypedDict
import cv2
@@ -101,6 +102,15 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool]
Content : TypeAlias = Dict[str, Any]
Token : TypeAlias = str
Session = TypedDict('Session',
{
'access_token': Token,
'refresh_token': Token,
'created_at': datetime,
'expires_at': datetime
})
Command : TypeAlias = str
CommandSet : TypeAlias = Dict[str, List[Command]]
+2
View File
@@ -5,3 +5,5 @@ opencv-python==4.12.0.88
psutil==7.1.3
tqdm==4.67.1
scipy==1.16.3
starlette==0.50.0
uvicorn==0.34.0
+154
View File
@@ -0,0 +1,154 @@
import os
from datetime import timedelta
from typing import Iterator
import pytest
from starlette.testclient import TestClient
from facefusion import metadata, session_manager
from facefusion.apis.core import create_api
from facefusion.types import Session
@pytest.fixture(scope = 'module')
def test_client() -> Iterator[TestClient]:
with TestClient(create_api()) as test_client:
yield test_client
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
def test_create_session(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
assert session_manager.get_session(create_session_body.get('access_token'))
assert create_session_response.status_code == 201
create_session_response = test_client.post('/session', json =
{
'api_key': 'TEST',
'client_version': metadata.get('version')
})
assert create_session_response.status_code == 401
os.environ['FACEFUSION_API_KEY'] = 'TEST'
create_session_response = test_client.post('/session', json =
{
'api_key': 'INVALID',
'client_version': metadata.get('version')
})
assert create_session_response.status_code == 401
os.environ['FACEFUSION_API_KEY'] = 'TEST'
create_session_response = test_client.post('/session', json =
{
'api_key': 'TEST',
'client_version': metadata.get('version')
})
assert create_session_response.status_code == 201
del os.environ['FACEFUSION_API_KEY']
def test_get_session(test_client : TestClient) -> None:
get_session_response = test_client.get('/session')
assert get_session_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
get_session_response = test_client.get('/session', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert get_session_response.status_code == 200
access_token = create_session_body.get('access_token')
session : Session = session_manager.get_session(access_token)
session_manager.set_session(access_token,
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token'),
'created_at': session.get('created_at'),
'expires_at': session.get('expires_at') - timedelta(hours = 1)
})
get_session_response = test_client.get('/session', headers =
{
'Authorization': 'Bearer ' + access_token
})
assert get_session_response.status_code == 426
def test_refresh_session(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
refresh_session_response = test_client.put('/session', json =
{
'refresh_token': 'INVALID'
})
assert refresh_session_response.status_code == 401
refresh_session_response = test_client.put('/session', json =
{
'refresh_token': create_session_body.get('refresh_token')
})
refresh_session_body = refresh_session_response.json()
assert session_manager.get_session(create_session_body.get('access_token')) is None
assert session_manager.get_session(refresh_session_body.get('access_token'))
assert refresh_session_response.status_code == 200
refresh_session_response = test_client.put('/session', json =
{
'refresh_token': create_session_body.get('refresh_token')
})
assert refresh_session_response.status_code == 401
def test_destroy_session(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
delete_session_response = test_client.delete('/session', headers =
{
'Authorization': 'Bearer INVALID'
})
assert delete_session_response.status_code == 401
delete_session_response = test_client.delete('/session', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert session_manager.get_session(create_session_body.get('access_token')) is None
assert delete_session_response.status_code == 200
+80
View File
@@ -0,0 +1,80 @@
from typing import Iterator
import pytest
from starlette.testclient import TestClient
from facefusion import metadata, session_manager, state_manager
from facefusion.apis.core import create_api
@pytest.fixture(scope = 'module')
def test_client() -> Iterator[TestClient]:
state_manager.init_item('execution_providers', [ 'cpu' ])
with TestClient(create_api()) as test_client:
yield test_client
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
def test_get_state(test_client : TestClient) -> None:
get_state_response = test_client.get('/state')
assert get_state_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
get_state_response = test_client.get('/state', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
get_state_body = get_state_response.json()
assert get_state_body.get('execution_providers') == [ 'cpu' ]
assert get_state_response.status_code == 200
def test_set_state(test_client : TestClient) -> None:
set_state_response = test_client.put('/state', json =
{
'execution_providers': [ 'cuda' ]
})
assert set_state_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
set_state_response = test_client.put('/state', json =
{
'execution_providers': [ 'cuda' ]
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
set_state_body = set_state_response.json()
assert set_state_body.get('execution_providers') == [ 'cuda' ]
assert set_state_response.status_code == 200
set_state_response = test_client.put('/state', json =
{
'invalid': 'invalid'
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
set_state_body = set_state_response.json()
assert set_state_body.get('invalid') is None
assert set_state_response.status_code == 200
+1 -1
View File
@@ -106,7 +106,7 @@ def test_find_jobs() -> None:
assert 'job-test-find-jobs-1' in find_jobs('drafted')
assert 'job-test-find-jobs-2' in find_jobs('drafted')
assert not find_jobs('queued')
assert find_jobs('queued') == {}
move_job_file('job-test-find-jobs-1', 'queued')
+1 -1
View File
@@ -6,7 +6,7 @@ from facefusion.json import read_json, write_json
def test_read_json() -> None:
_, json_path = tempfile.mkstemp(suffix = '.json')
assert not read_json(json_path)
assert read_json(json_path) is None
write_json(json_path, {})
+45
View File
@@ -0,0 +1,45 @@
import secrets
from datetime import timedelta
from facefusion.session_manager import clear_session, create_session, get_session, set_session, validate_session
def test_get_and_set_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
set_session(access_token, session)
assert get_session(access_token) == session
def test_validate_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
set_session(access_token, session)
assert validate_session(access_token) is True
set_session(access_token,
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token'),
'created_at': session.get('created_at'),
'expires_at': session.get('expires_at') - timedelta(hours = 1)
})
assert validate_session(access_token) is False
def test_clear_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
set_session(access_token, session)
assert validate_session(access_token) is True
clear_session(access_token)
assert validate_session(access_token) is None