mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 04:55:57 +02:00
Local API (#982)
* Introduce API scelleton * Raw impl for session * Simple state endpoint * Apply _body naming * Finalize session testing and comment out tons of useless code * Clean and refactor part1 * Clean and refactor part2 * Clean and refactor part2 * Clean and refactor part2 * Clean and refactor part2 * Refactor middleware * Refactor middleware * Clean and refactor part3 * TDD and 2 beers * TDD and 2 beers * Complete state endpoints * You can only set what is already present * Use only JSON as response * Use default logger * Improve auth extraction * Extend api command with more args * Adjust API messages
This commit is contained in:
@@ -35,6 +35,7 @@ jobs:
|
||||
python-version: '3.12'
|
||||
- run: python install.py --onnxruntime default --skip-conda
|
||||
- run: pip install pytest
|
||||
- run: pip install httpx
|
||||
- run: pytest
|
||||
report:
|
||||
needs: test
|
||||
@@ -52,6 +53,7 @@ jobs:
|
||||
- run: pip install coveralls
|
||||
- run: pip install pytest
|
||||
- run: pip install pytest-cov
|
||||
- run: pip install httpx
|
||||
- run: pytest tests --cov facefusion
|
||||
- run: coveralls --service github
|
||||
env:
|
||||
|
||||
@@ -31,6 +31,7 @@ commands:
|
||||
batch-run run the program in batch mode
|
||||
force-download force automate downloads and exit
|
||||
benchmark benchmark the program
|
||||
api start the API server
|
||||
job-list list jobs by status
|
||||
job-create create a drafted job
|
||||
job-submit submit a drafted job to become a queued job
|
||||
|
||||
@@ -112,6 +112,10 @@ benchmark_mode =
|
||||
benchmark_resolutions =
|
||||
benchmark_cycle_count =
|
||||
|
||||
[api]
|
||||
api_host =
|
||||
api_port =
|
||||
|
||||
[execution]
|
||||
execution_device_ids =
|
||||
execution_providers =
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Route
|
||||
|
||||
from facefusion.apis.session import create_session
|
||||
from facefusion.apis.session import create_session_guard
|
||||
from facefusion.apis.session import destroy_session
|
||||
from facefusion.apis.session import get_session
|
||||
from facefusion.apis.session import refresh_session
|
||||
from facefusion.apis.state import get_state
|
||||
from facefusion.apis.state import set_state
|
||||
|
||||
|
||||
def create_api() -> Starlette:
|
||||
session_guard = Middleware(create_session_guard)
|
||||
routes =\
|
||||
[
|
||||
Route('/session', create_session, methods = [ 'POST' ]),
|
||||
Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/session', refresh_session, methods = [ 'PUT' ]),
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
api = Starlette(routes = routes)
|
||||
api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ])
|
||||
|
||||
return api
|
||||
@@ -0,0 +1,12 @@
|
||||
from facefusion.types import Locals
|
||||
|
||||
LOCALS : Locals =\
|
||||
{
|
||||
'en':
|
||||
{
|
||||
'ok': 'ok',
|
||||
'something_went_wrong': 'something went wrong',
|
||||
'invalid_access_token': 'invalid access token',
|
||||
'invalid_refresh_token': 'invalid refresh token'
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, HTTP_426_UPGRADE_REQUIRED
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from facefusion import session_manager, translator
|
||||
from facefusion.types import Token
|
||||
|
||||
|
||||
async def create_session(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
|
||||
if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'):
|
||||
session = session_manager.create_session()
|
||||
session_manager.set_session(session.get('access_token'), session)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token')
|
||||
}, status_code = HTTP_201_CREATED)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
async def get_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
|
||||
if session:
|
||||
return JSONResponse(
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token'),
|
||||
'created_at': session.get('created_at').isoformat(),
|
||||
'expires_at': session.get('expires_at').isoformat()
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
async def refresh_session(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
|
||||
for access_token, session in session_manager.SESSIONS.items():
|
||||
if session.get('refresh_token') == body.get('refresh_token'):
|
||||
session_manager.clear_session(access_token)
|
||||
session = session_manager.create_session()
|
||||
session_manager.set_session(session.get('access_token'), session)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token')
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
async def destroy_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
|
||||
if session:
|
||||
session_manager.clear_session(access_token)
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('ok', __package__)
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
def create_session_guard(app : ASGIApp) -> ASGIApp:
|
||||
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
|
||||
access_token = extract_access_token(Headers(scope = scope))
|
||||
|
||||
if access_token and session_manager.validate_session(access_token):
|
||||
return await app(scope, receive, send)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
|
||||
if session:
|
||||
response = JSONResponse(
|
||||
{
|
||||
'message': translator.get('invalid_access_token', __package__)
|
||||
}, status_code = HTTP_426_UPGRADE_REQUIRED)
|
||||
|
||||
return await response(scope, receive, send)
|
||||
|
||||
response = JSONResponse(
|
||||
{
|
||||
'message': translator.get('invalid_access_token', __package__)
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
return await response(scope, receive, send)
|
||||
|
||||
return middleware
|
||||
|
||||
|
||||
def extract_access_token(headers : Headers) -> Optional[Token]:
|
||||
auth_header = headers.get('Authorization')
|
||||
|
||||
if auth_header:
|
||||
auth_prefix, _, access_token = auth_header.partition(' ')
|
||||
|
||||
if auth_prefix.lower() == 'bearer' and access_token:
|
||||
return access_token
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import get_args
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
from facefusion import state_manager
|
||||
from facefusion.types import StateKey
|
||||
|
||||
|
||||
async def get_state(request : Request) -> JSONResponse:
|
||||
return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def set_state(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
|
||||
for key, value in body.items():
|
||||
if key in get_args(StateKey):
|
||||
state_manager.set_item(key, value)
|
||||
|
||||
return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK)
|
||||
|
||||
@@ -115,6 +115,9 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
|
||||
apply_state_item('benchmark_mode', args.get('benchmark_mode'))
|
||||
apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions'))
|
||||
apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count'))
|
||||
# api
|
||||
apply_state_item('api_host', args.get('api_host'))
|
||||
apply_state_item('api_port', args.get('api_port'))
|
||||
# memory
|
||||
apply_state_item('video_memory_strategy', args.get('video_memory_strategy'))
|
||||
apply_state_item('system_memory_limit', args.get('system_memory_limit'))
|
||||
|
||||
@@ -5,7 +5,10 @@ import signal
|
||||
import sys
|
||||
from time import time
|
||||
|
||||
import uvicorn
|
||||
|
||||
from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, state_manager, translator, voice_extractor
|
||||
from facefusion.apis.core import create_api
|
||||
from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources
|
||||
from facefusion.exit_helper import hard_exit, signal_exit
|
||||
@@ -55,6 +58,11 @@ def route(args : Args) -> None:
|
||||
hard_exit(2)
|
||||
benchmarker.render()
|
||||
|
||||
if state_manager.get_item('command') == 'api':
|
||||
logger.info(translator.get('api_started').format(host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')), __name__)
|
||||
uvicorn.run(create_api(), host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port'))
|
||||
hard_exit(1)
|
||||
|
||||
if state_manager.get_item('command') in [ 'job-list', 'job-create', 'job-submit', 'job-submit-all', 'job-delete', 'job-delete-all', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ]:
|
||||
if not job_manager.init_jobs(state_manager.get_item('jobs_path')):
|
||||
hard_exit(1)
|
||||
|
||||
@@ -48,10 +48,9 @@ LOCALES : Locales =\
|
||||
'no_source_face_detected': 'no source face detected',
|
||||
'processor_not_loaded': 'processor {processor} could not be loaded',
|
||||
'processor_not_implemented': 'processor {processor} not implemented correctly',
|
||||
'ui_layout_not_loaded': 'ui layout {ui_layout} could not be loaded',
|
||||
'ui_layout_not_implemented': 'ui layout {ui_layout} not implemented correctly',
|
||||
'stream_not_loaded': 'stream {stream_mode} could not be loaded',
|
||||
'stream_not_supported': 'stream not supported',
|
||||
'api_started': 'started API on {host}:{port}',
|
||||
'job_created': 'job {job_id} created',
|
||||
'job_not_created': 'job {job_id} not created',
|
||||
'job_submitted': 'job {job_id} submitted',
|
||||
@@ -154,6 +153,8 @@ LOCALES : Locales =\
|
||||
'benchmark_mode': 'choose the benchmark mode',
|
||||
'benchmark_resolutions': 'choose the resolutions for the benchmarks (choices: {choices}, ...)',
|
||||
'benchmark_cycle_count': 'specify the amount of cycles per benchmark',
|
||||
'api_host': 'specify the API host',
|
||||
'api_port': 'specify the API port',
|
||||
'execution_device_ids': 'specify the devices used for processing',
|
||||
'execution_providers': 'inference using different providers (choices: {choices}, ...)',
|
||||
'execution_thread_count': 'specify the amount of parallel threads while processing',
|
||||
@@ -165,6 +166,7 @@ LOCALES : Locales =\
|
||||
'batch_run': 'run the program in batch mode',
|
||||
'force_download': 'force automate downloads and exit',
|
||||
'benchmark': 'benchmark the program',
|
||||
'api': 'start the API server',
|
||||
'job_id': 'specify the job id',
|
||||
'job_status': 'specify the job status',
|
||||
'step_index': 'specify the step index',
|
||||
|
||||
@@ -220,6 +220,14 @@ def create_benchmark_program() -> ArgumentParser:
|
||||
return program
|
||||
|
||||
|
||||
def create_api_program() -> ArgumentParser:
|
||||
program = ArgumentParser(add_help = False)
|
||||
group_api = program.add_argument_group('api')
|
||||
group_api.add_argument('--api-host', help = translator.get('help.api_host'), default = config.get_str_value('api', 'api_host', '127.0.0.1'))
|
||||
group_api.add_argument('--api-port', help = translator.get('help.api_port'), type = int, default = config.get_int_value('api', 'api_port', '8000'))
|
||||
return program
|
||||
|
||||
|
||||
def create_execution_program() -> ArgumentParser:
|
||||
program = ArgumentParser(add_help = False)
|
||||
available_execution_providers = get_available_execution_providers()
|
||||
@@ -292,6 +300,7 @@ def create_program() -> ArgumentParser:
|
||||
sub_program.add_parser('batch-run', help = translator.get('help.batch_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_pattern_program(), create_target_pattern_program(), create_output_pattern_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
|
||||
sub_program.add_parser('force-download', help = translator.get('help.force_download'), parents = [ create_download_providers_program(), create_download_scope_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
|
||||
sub_program.add_parser('benchmark', help = translator.get('help.benchmark'), parents = [ create_temp_path_program(), collect_step_program(), create_benchmark_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
|
||||
sub_program.add_parser('api', help = translator.get('help.api'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_api_program(), collect_job_program() ], formatter_class = create_help_formatter_large)
|
||||
# job manager
|
||||
sub_program.add_parser('job-list', help = translator.get('help.job_list'), parents = [ create_job_status_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
|
||||
sub_program.add_parser('job-create', help = translator.get('help.job_create'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large)
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.types import Session, Token
|
||||
|
||||
SESSIONS : Dict[Token, Session] = {}
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
session : Session =\
|
||||
{
|
||||
'access_token': secrets.token_urlsafe(128),
|
||||
'refresh_token': secrets.token_urlsafe(128),
|
||||
'created_at': datetime.now(),
|
||||
'expires_at': datetime.now() + timedelta(minutes = 10)
|
||||
}
|
||||
|
||||
return session
|
||||
|
||||
|
||||
def get_session(access_token : Token) -> Optional[Session]:
|
||||
return SESSIONS.get(access_token)
|
||||
|
||||
|
||||
def set_session(access_token : Token, session : Session) -> None:
|
||||
SESSIONS[access_token] = session
|
||||
|
||||
|
||||
def validate_session(access_token : Token) -> bool:
|
||||
session = get_session(access_token)
|
||||
return session and datetime.now() <= session.get('expires_at')
|
||||
|
||||
|
||||
def clear_session(access_token : Token) -> None:
|
||||
if access_token in SESSIONS:
|
||||
del SESSIONS[access_token]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeAlias, TypedDict
|
||||
|
||||
import cv2
|
||||
@@ -101,6 +102,15 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool]
|
||||
|
||||
Content : TypeAlias = Dict[str, Any]
|
||||
|
||||
Token : TypeAlias = str
|
||||
Session = TypedDict('Session',
|
||||
{
|
||||
'access_token': Token,
|
||||
'refresh_token': Token,
|
||||
'created_at': datetime,
|
||||
'expires_at': datetime
|
||||
})
|
||||
|
||||
Command : TypeAlias = str
|
||||
CommandSet : TypeAlias = Dict[str, List[Command]]
|
||||
|
||||
|
||||
@@ -5,3 +5,5 @@ opencv-python==4.12.0.88
|
||||
psutil==7.1.3
|
||||
tqdm==4.67.1
|
||||
scipy==1.16.3
|
||||
starlette==0.50.0
|
||||
uvicorn==0.34.0
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, session_manager
|
||||
from facefusion.apis.core import create_api
|
||||
from facefusion.types import Session
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
def test_client() -> Iterator[TestClient]:
|
||||
with TestClient(create_api()) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
|
||||
|
||||
def test_create_session(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token'))
|
||||
assert create_session_response.status_code == 201
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'api_key': 'TEST',
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
|
||||
assert create_session_response.status_code == 401
|
||||
|
||||
os.environ['FACEFUSION_API_KEY'] = 'TEST'
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'api_key': 'INVALID',
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
|
||||
assert create_session_response.status_code == 401
|
||||
|
||||
os.environ['FACEFUSION_API_KEY'] = 'TEST'
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'api_key': 'TEST',
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
|
||||
assert create_session_response.status_code == 201
|
||||
|
||||
del os.environ['FACEFUSION_API_KEY']
|
||||
|
||||
|
||||
def test_get_session(test_client : TestClient) -> None:
|
||||
get_session_response = test_client.get('/session')
|
||||
|
||||
assert get_session_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
get_session_response = test_client.get('/session', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert get_session_response.status_code == 200
|
||||
|
||||
access_token = create_session_body.get('access_token')
|
||||
session : Session = session_manager.get_session(access_token)
|
||||
session_manager.set_session(access_token,
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token'),
|
||||
'created_at': session.get('created_at'),
|
||||
'expires_at': session.get('expires_at') - timedelta(hours = 1)
|
||||
})
|
||||
|
||||
get_session_response = test_client.get('/session', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
assert get_session_response.status_code == 426
|
||||
|
||||
|
||||
def test_refresh_session(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
refresh_session_response = test_client.put('/session', json =
|
||||
{
|
||||
'refresh_token': 'INVALID'
|
||||
})
|
||||
|
||||
assert refresh_session_response.status_code == 401
|
||||
|
||||
refresh_session_response = test_client.put('/session', json =
|
||||
{
|
||||
'refresh_token': create_session_body.get('refresh_token')
|
||||
})
|
||||
refresh_session_body = refresh_session_response.json()
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token')) is None
|
||||
|
||||
assert session_manager.get_session(refresh_session_body.get('access_token'))
|
||||
|
||||
assert refresh_session_response.status_code == 200
|
||||
|
||||
refresh_session_response = test_client.put('/session', json =
|
||||
{
|
||||
'refresh_token': create_session_body.get('refresh_token')
|
||||
})
|
||||
|
||||
assert refresh_session_response.status_code == 401
|
||||
|
||||
|
||||
def test_destroy_session(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
delete_session_response = test_client.delete('/session', headers =
|
||||
{
|
||||
'Authorization': 'Bearer INVALID'
|
||||
})
|
||||
|
||||
assert delete_session_response.status_code == 401
|
||||
|
||||
delete_session_response = test_client.delete('/session', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token')) is None
|
||||
|
||||
assert delete_session_response.status_code == 200
|
||||
@@ -0,0 +1,80 @@
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, session_manager, state_manager
|
||||
from facefusion.apis.core import create_api
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
def test_client() -> Iterator[TestClient]:
|
||||
state_manager.init_item('execution_providers', [ 'cpu' ])
|
||||
|
||||
with TestClient(create_api()) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
|
||||
|
||||
def test_get_state(test_client : TestClient) -> None:
|
||||
get_state_response = test_client.get('/state')
|
||||
|
||||
assert get_state_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
get_state_response = test_client.get('/state', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
get_state_body = get_state_response.json()
|
||||
|
||||
assert get_state_body.get('execution_providers') == [ 'cpu' ]
|
||||
assert get_state_response.status_code == 200
|
||||
|
||||
|
||||
def test_set_state(test_client : TestClient) -> None:
|
||||
set_state_response = test_client.put('/state', json =
|
||||
{
|
||||
'execution_providers': [ 'cuda' ]
|
||||
})
|
||||
|
||||
assert set_state_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
set_state_response = test_client.put('/state', json =
|
||||
{
|
||||
'execution_providers': [ 'cuda' ]
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
set_state_body = set_state_response.json()
|
||||
|
||||
assert set_state_body.get('execution_providers') == [ 'cuda' ]
|
||||
assert set_state_response.status_code == 200
|
||||
|
||||
set_state_response = test_client.put('/state', json =
|
||||
{
|
||||
'invalid': 'invalid'
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
set_state_body = set_state_response.json()
|
||||
|
||||
assert set_state_body.get('invalid') is None
|
||||
assert set_state_response.status_code == 200
|
||||
@@ -106,7 +106,7 @@ def test_find_jobs() -> None:
|
||||
|
||||
assert 'job-test-find-jobs-1' in find_jobs('drafted')
|
||||
assert 'job-test-find-jobs-2' in find_jobs('drafted')
|
||||
assert not find_jobs('queued')
|
||||
assert find_jobs('queued') == {}
|
||||
|
||||
move_job_file('job-test-find-jobs-1', 'queued')
|
||||
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ from facefusion.json import read_json, write_json
|
||||
def test_read_json() -> None:
|
||||
_, json_path = tempfile.mkstemp(suffix = '.json')
|
||||
|
||||
assert not read_json(json_path)
|
||||
assert read_json(json_path) is None
|
||||
|
||||
write_json(json_path, {})
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import secrets
|
||||
from datetime import timedelta
|
||||
|
||||
from facefusion.session_manager import clear_session, create_session, get_session, set_session, validate_session
|
||||
|
||||
|
||||
def test_get_and_set_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
|
||||
set_session(access_token, session)
|
||||
|
||||
assert get_session(access_token) == session
|
||||
|
||||
|
||||
def test_validate_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
|
||||
set_session(access_token, session)
|
||||
|
||||
assert validate_session(access_token) is True
|
||||
|
||||
set_session(access_token,
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token'),
|
||||
'created_at': session.get('created_at'),
|
||||
'expires_at': session.get('expires_at') - timedelta(hours = 1)
|
||||
})
|
||||
|
||||
assert validate_session(access_token) is False
|
||||
|
||||
|
||||
def test_clear_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
|
||||
set_session(access_token, session)
|
||||
|
||||
assert validate_session(access_token) is True
|
||||
|
||||
clear_session(access_token)
|
||||
|
||||
assert validate_session(access_token) is None
|
||||
Reference in New Issue
Block a user