From 3daf0496842764dcacc57f08bd786c58a0052ade Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Mon, 17 Nov 2025 08:36:31 +0100 Subject: [PATCH] Local API (#982) * Introduce API scelleton * Raw impl for session * Simple state endpoint * Apply _body naming * Finalize session testing and comment out tons of useless code * Clean and refactor part1 * Clean and refactor part2 * Clean and refactor part2 * Clean and refactor part2 * Clean and refactor part2 * Refactor middleware * Refactor middleware * Clean and refactor part3 * TDD and 2 beers * TDD and 2 beers * Complete state endpoints * You can only set what is already present * Use only JSON as response * Use default logger * Improve auth extraction * Extend api command with more args * Adjust API messages --- .github/workflows/ci.yml | 2 + README.md | 1 + facefusion.ini | 4 + facefusion/apis/__init__.py | 0 facefusion/apis/core.py | 30 +++++++ facefusion/apis/locals.py | 12 +++ facefusion/apis/session.py | 131 +++++++++++++++++++++++++++++ facefusion/apis/state.py | 23 +++++ facefusion/args.py | 3 + facefusion/core.py | 8 ++ facefusion/locales.py | 6 +- facefusion/program.py | 9 ++ facefusion/session_manager.py | 39 +++++++++ facefusion/types.py | 10 +++ requirements.txt | 2 + tests/test_api_session.py | 154 ++++++++++++++++++++++++++++++++++ tests/test_api_state.py | 80 ++++++++++++++++++ tests/test_job_manager.py | 2 +- tests/test_json.py | 2 +- tests/test_session_manager.py | 45 ++++++++++ 20 files changed, 559 insertions(+), 4 deletions(-) create mode 100644 facefusion/apis/__init__.py create mode 100644 facefusion/apis/core.py create mode 100644 facefusion/apis/locals.py create mode 100644 facefusion/apis/session.py create mode 100644 facefusion/apis/state.py create mode 100644 facefusion/session_manager.py create mode 100644 tests/test_api_session.py create mode 100644 tests/test_api_state.py create mode 100644 tests/test_session_manager.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 829cc9be..8d1f9d86 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,7 @@ jobs: python-version: '3.12' - run: python install.py --onnxruntime default --skip-conda - run: pip install pytest + - run: pip install httpx - run: pytest report: needs: test @@ -52,6 +53,7 @@ jobs: - run: pip install coveralls - run: pip install pytest - run: pip install pytest-cov + - run: pip install httpx - run: pytest tests --cov facefusion - run: coveralls --service github env: diff --git a/README.md b/README.md index 2942ff05..67503790 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ commands: batch-run run the program in batch mode force-download force automate downloads and exit benchmark benchmark the program + api start the API server job-list list jobs by status job-create create a drafted job job-submit submit a drafted job to become a queued job diff --git a/facefusion.ini b/facefusion.ini index d74e8364..bcd9a005 100644 --- a/facefusion.ini +++ b/facefusion.ini @@ -112,6 +112,10 @@ benchmark_mode = benchmark_resolutions = benchmark_cycle_count = +[api] +api_host = +api_port = + [execution] execution_device_ids = execution_providers = diff --git a/facefusion/apis/__init__.py b/facefusion/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py new file mode 100644 index 00000000..67a08502 --- /dev/null +++ b/facefusion/apis/core.py @@ -0,0 +1,30 @@ +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Route + +from facefusion.apis.session import create_session +from facefusion.apis.session import create_session_guard +from facefusion.apis.session import destroy_session +from facefusion.apis.session import get_session +from facefusion.apis.session import refresh_session +from facefusion.apis.state import get_state +from facefusion.apis.state import set_state + + +def create_api() -> Starlette: + session_guard = Middleware(create_session_guard) + routes =\ + [ + Route('/session', create_session, methods = [ 'POST' ]), + Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]), + Route('/session', refresh_session, methods = [ 'PUT' ]), + Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]), + Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]), + Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]) + ] + + api = Starlette(routes = routes) + api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ]) + + return api diff --git a/facefusion/apis/locals.py b/facefusion/apis/locals.py new file mode 100644 index 00000000..97be9e00 --- /dev/null +++ b/facefusion/apis/locals.py @@ -0,0 +1,12 @@ +from facefusion.types import Locals + +LOCALS : Locals =\ +{ + 'en': + { + 'ok': 'ok', + 'something_went_wrong': 'something went wrong', + 'invalid_access_token': 'invalid access token', + 'invalid_refresh_token': 'invalid refresh token' + } +} diff --git a/facefusion/apis/session.py b/facefusion/apis/session.py new file mode 100644 index 00000000..d18946dc --- /dev/null +++ b/facefusion/apis/session.py @@ -0,0 +1,131 @@ +import os +from typing import Optional + +from starlette.datastructures import Headers +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, HTTP_426_UPGRADE_REQUIRED +from starlette.types import ASGIApp, Receive, Scope, Send + +from facefusion import session_manager, translator +from facefusion.types import Token + + +async def create_session(request : Request) -> JSONResponse: + body = await request.json() + + if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'): + session = session_manager.create_session() + session_manager.set_session(session.get('access_token'), session) + + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token') + }, status_code = HTTP_201_CREATED) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def get_session(request : Request) -> JSONResponse: + access_token = extract_access_token(request.headers) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at').isoformat(), + 'expires_at': session.get('expires_at').isoformat() + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def refresh_session(request : Request) -> JSONResponse: + body = await request.json() + + for access_token, session in session_manager.SESSIONS.items(): + if session.get('refresh_token') == body.get('refresh_token'): + session_manager.clear_session(access_token) + session = session_manager.create_session() + session_manager.set_session(session.get('access_token'), session) + + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token') + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def destroy_session(request : Request) -> JSONResponse: + access_token = extract_access_token(request.headers) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + session_manager.clear_session(access_token) + return JSONResponse( + { + 'message': translator.get('ok', __package__) + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +def create_session_guard(app : ASGIApp) -> ASGIApp: + async def middleware(scope : Scope, receive : Receive, send : Send) -> None: + access_token = extract_access_token(Headers(scope = scope)) + + if access_token and session_manager.validate_session(access_token): + return await app(scope, receive, send) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + response = JSONResponse( + { + 'message': translator.get('invalid_access_token', __package__) + }, status_code = HTTP_426_UPGRADE_REQUIRED) + + return await response(scope, receive, send) + + response = JSONResponse( + { + 'message': translator.get('invalid_access_token', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + return await response(scope, receive, send) + + return middleware + + +def extract_access_token(headers : Headers) -> Optional[Token]: + auth_header = headers.get('Authorization') + + if auth_header: + auth_prefix, _, access_token = auth_header.partition(' ') + + if auth_prefix.lower() == 'bearer' and access_token: + return access_token + + return None diff --git a/facefusion/apis/state.py b/facefusion/apis/state.py new file mode 100644 index 00000000..682240c9 --- /dev/null +++ b/facefusion/apis/state.py @@ -0,0 +1,23 @@ +from typing import get_args + +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_200_OK + +from facefusion import state_manager +from facefusion.types import StateKey + + +async def get_state(request : Request) -> JSONResponse: + return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK) + + +async def set_state(request : Request) -> JSONResponse: + body = await request.json() + + for key, value in body.items(): + if key in get_args(StateKey): + state_manager.set_item(key, value) + + return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK) + diff --git a/facefusion/args.py b/facefusion/args.py index 269711ec..095995ee 100644 --- a/facefusion/args.py +++ b/facefusion/args.py @@ -115,6 +115,9 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('benchmark_mode', args.get('benchmark_mode')) apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions')) apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count')) + # api + apply_state_item('api_host', args.get('api_host')) + apply_state_item('api_port', args.get('api_port')) # memory apply_state_item('video_memory_strategy', args.get('video_memory_strategy')) apply_state_item('system_memory_limit', args.get('system_memory_limit')) diff --git a/facefusion/core.py b/facefusion/core.py index 228768cd..a8b5a127 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -5,7 +5,10 @@ import signal import sys from time import time +import uvicorn + from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, state_manager, translator, voice_extractor +from facefusion.apis.core import create_api from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args from facefusion.download import conditional_download_hashes, conditional_download_sources from facefusion.exit_helper import hard_exit, signal_exit @@ -55,6 +58,11 @@ def route(args : Args) -> None: hard_exit(2) benchmarker.render() + if state_manager.get_item('command') == 'api': + logger.info(translator.get('api_started').format(host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')), __name__) + uvicorn.run(create_api(), host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')) + hard_exit(1) + if state_manager.get_item('command') in [ 'job-list', 'job-create', 'job-submit', 'job-submit-all', 'job-delete', 'job-delete-all', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ]: if not job_manager.init_jobs(state_manager.get_item('jobs_path')): hard_exit(1) diff --git a/facefusion/locales.py b/facefusion/locales.py index 1b8696bf..bfcc8a39 100644 --- a/facefusion/locales.py +++ b/facefusion/locales.py @@ -48,10 +48,9 @@ LOCALES : Locales =\ 'no_source_face_detected': 'no source face detected', 'processor_not_loaded': 'processor {processor} could not be loaded', 'processor_not_implemented': 'processor {processor} not implemented correctly', - 'ui_layout_not_loaded': 'ui layout {ui_layout} could not be loaded', - 'ui_layout_not_implemented': 'ui layout {ui_layout} not implemented correctly', 'stream_not_loaded': 'stream {stream_mode} could not be loaded', 'stream_not_supported': 'stream not supported', + 'api_started': 'started API on {host}:{port}', 'job_created': 'job {job_id} created', 'job_not_created': 'job {job_id} not created', 'job_submitted': 'job {job_id} submitted', @@ -154,6 +153,8 @@ LOCALES : Locales =\ 'benchmark_mode': 'choose the benchmark mode', 'benchmark_resolutions': 'choose the resolutions for the benchmarks (choices: {choices}, ...)', 'benchmark_cycle_count': 'specify the amount of cycles per benchmark', + 'api_host': 'specify the API host', + 'api_port': 'specify the API port', 'execution_device_ids': 'specify the devices used for processing', 'execution_providers': 'inference using different providers (choices: {choices}, ...)', 'execution_thread_count': 'specify the amount of parallel threads while processing', @@ -165,6 +166,7 @@ LOCALES : Locales =\ 'batch_run': 'run the program in batch mode', 'force_download': 'force automate downloads and exit', 'benchmark': 'benchmark the program', + 'api': 'start the API server', 'job_id': 'specify the job id', 'job_status': 'specify the job status', 'step_index': 'specify the step index', diff --git a/facefusion/program.py b/facefusion/program.py index 2aa354c3..61e60917 100755 --- a/facefusion/program.py +++ b/facefusion/program.py @@ -220,6 +220,14 @@ def create_benchmark_program() -> ArgumentParser: return program +def create_api_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_api = program.add_argument_group('api') + group_api.add_argument('--api-host', help = translator.get('help.api_host'), default = config.get_str_value('api', 'api_host', '127.0.0.1')) + group_api.add_argument('--api-port', help = translator.get('help.api_port'), type = int, default = config.get_int_value('api', 'api_port', '8000')) + return program + + def create_execution_program() -> ArgumentParser: program = ArgumentParser(add_help = False) available_execution_providers = get_available_execution_providers() @@ -292,6 +300,7 @@ def create_program() -> ArgumentParser: sub_program.add_parser('batch-run', help = translator.get('help.batch_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_pattern_program(), create_target_pattern_program(), create_output_pattern_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('force-download', help = translator.get('help.force_download'), parents = [ create_download_providers_program(), create_download_scope_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('benchmark', help = translator.get('help.benchmark'), parents = [ create_temp_path_program(), collect_step_program(), create_benchmark_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('api', help = translator.get('help.api'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_api_program(), collect_job_program() ], formatter_class = create_help_formatter_large) # job manager sub_program.add_parser('job-list', help = translator.get('help.job_list'), parents = [ create_job_status_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('job-create', help = translator.get('help.job_create'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) diff --git a/facefusion/session_manager.py b/facefusion/session_manager.py new file mode 100644 index 00000000..8a27abca --- /dev/null +++ b/facefusion/session_manager.py @@ -0,0 +1,39 @@ +import secrets +from datetime import datetime, timedelta +from typing import Dict +from typing import Optional + +from facefusion.types import Session, Token + +SESSIONS : Dict[Token, Session] = {} + + +def create_session() -> Session: + session : Session =\ + { + 'access_token': secrets.token_urlsafe(128), + 'refresh_token': secrets.token_urlsafe(128), + 'created_at': datetime.now(), + 'expires_at': datetime.now() + timedelta(minutes = 10) + } + + return session + + +def get_session(access_token : Token) -> Optional[Session]: + return SESSIONS.get(access_token) + + +def set_session(access_token : Token, session : Session) -> None: + SESSIONS[access_token] = session + + +def validate_session(access_token : Token) -> bool: + session = get_session(access_token) + return session and datetime.now() <= session.get('expires_at') + + +def clear_session(access_token : Token) -> None: + if access_token in SESSIONS: + del SESSIONS[access_token] + diff --git a/facefusion/types.py b/facefusion/types.py index 05933461..077147ee 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -1,4 +1,5 @@ from collections import namedtuple +from datetime import datetime from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeAlias, TypedDict import cv2 @@ -101,6 +102,15 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool] Content : TypeAlias = Dict[str, Any] +Token : TypeAlias = str +Session = TypedDict('Session', +{ + 'access_token': Token, + 'refresh_token': Token, + 'created_at': datetime, + 'expires_at': datetime +}) + Command : TypeAlias = str CommandSet : TypeAlias = Dict[str, List[Command]] diff --git a/requirements.txt b/requirements.txt index 22122b95..1865aba8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ opencv-python==4.12.0.88 psutil==7.1.3 tqdm==4.67.1 scipy==1.16.3 +starlette==0.50.0 +uvicorn==0.34.0 diff --git a/tests/test_api_session.py b/tests/test_api_session.py new file mode 100644 index 00000000..29f83eb9 --- /dev/null +++ b/tests/test_api_session.py @@ -0,0 +1,154 @@ +import os +from datetime import timedelta +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager +from facefusion.apis.core import create_api +from facefusion.types import Session + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + + +def test_create_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + assert session_manager.get_session(create_session_body.get('access_token')) + assert create_session_response.status_code == 201 + + create_session_response = test_client.post('/session', json = + { + 'api_key': 'TEST', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 401 + + os.environ['FACEFUSION_API_KEY'] = 'TEST' + create_session_response = test_client.post('/session', json = + { + 'api_key': 'INVALID', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 401 + + os.environ['FACEFUSION_API_KEY'] = 'TEST' + create_session_response = test_client.post('/session', json = + { + 'api_key': 'TEST', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 201 + + del os.environ['FACEFUSION_API_KEY'] + + +def test_get_session(test_client : TestClient) -> None: + get_session_response = test_client.get('/session') + + assert get_session_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + get_session_response = test_client.get('/session', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert get_session_response.status_code == 200 + + access_token = create_session_body.get('access_token') + session : Session = session_manager.get_session(access_token) + session_manager.set_session(access_token, + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at'), + 'expires_at': session.get('expires_at') - timedelta(hours = 1) + }) + + get_session_response = test_client.get('/session', headers = + { + 'Authorization': 'Bearer ' + access_token + }) + + assert get_session_response.status_code == 426 + + +def test_refresh_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': 'INVALID' + }) + + assert refresh_session_response.status_code == 401 + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': create_session_body.get('refresh_token') + }) + refresh_session_body = refresh_session_response.json() + + assert session_manager.get_session(create_session_body.get('access_token')) is None + + assert session_manager.get_session(refresh_session_body.get('access_token')) + + assert refresh_session_response.status_code == 200 + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': create_session_body.get('refresh_token') + }) + + assert refresh_session_response.status_code == 401 + + +def test_destroy_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + delete_session_response = test_client.delete('/session', headers = + { + 'Authorization': 'Bearer INVALID' + }) + + assert delete_session_response.status_code == 401 + + delete_session_response = test_client.delete('/session', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert session_manager.get_session(create_session_body.get('access_token')) is None + + assert delete_session_response.status_code == 200 diff --git a/tests/test_api_state.py b/tests/test_api_state.py new file mode 100644 index 00000000..d15797c1 --- /dev/null +++ b/tests/test_api_state.py @@ -0,0 +1,80 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager, state_manager +from facefusion.apis.core import create_api + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + state_manager.init_item('execution_providers', [ 'cpu' ]) + + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + + +def test_get_state(test_client : TestClient) -> None: + get_state_response = test_client.get('/state') + + assert get_state_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + get_state_response = test_client.get('/state', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + get_state_body = get_state_response.json() + + assert get_state_body.get('execution_providers') == [ 'cpu' ] + assert get_state_response.status_code == 200 + + +def test_set_state(test_client : TestClient) -> None: + set_state_response = test_client.put('/state', json = + { + 'execution_providers': [ 'cuda' ] + }) + + assert set_state_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + set_state_response = test_client.put('/state', json = + { + 'execution_providers': [ 'cuda' ] + }, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + set_state_body = set_state_response.json() + + assert set_state_body.get('execution_providers') == [ 'cuda' ] + assert set_state_response.status_code == 200 + + set_state_response = test_client.put('/state', json = + { + 'invalid': 'invalid' + }, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + set_state_body = set_state_response.json() + + assert set_state_body.get('invalid') is None + assert set_state_response.status_code == 200 diff --git a/tests/test_job_manager.py b/tests/test_job_manager.py index 05b1fedc..27164146 100644 --- a/tests/test_job_manager.py +++ b/tests/test_job_manager.py @@ -106,7 +106,7 @@ def test_find_jobs() -> None: assert 'job-test-find-jobs-1' in find_jobs('drafted') assert 'job-test-find-jobs-2' in find_jobs('drafted') - assert not find_jobs('queued') + assert find_jobs('queued') == {} move_job_file('job-test-find-jobs-1', 'queued') diff --git a/tests/test_json.py b/tests/test_json.py index c1d8a387..9bdfd157 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -6,7 +6,7 @@ from facefusion.json import read_json, write_json def test_read_json() -> None: _, json_path = tempfile.mkstemp(suffix = '.json') - assert not read_json(json_path) + assert read_json(json_path) is None write_json(json_path, {}) diff --git a/tests/test_session_manager.py b/tests/test_session_manager.py new file mode 100644 index 00000000..c886c541 --- /dev/null +++ b/tests/test_session_manager.py @@ -0,0 +1,45 @@ +import secrets +from datetime import timedelta + +from facefusion.session_manager import clear_session, create_session, get_session, set_session, validate_session + + +def test_get_and_set_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert get_session(access_token) == session + + +def test_validate_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert validate_session(access_token) is True + + set_session(access_token, + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at'), + 'expires_at': session.get('expires_at') - timedelta(hours = 1) + }) + + assert validate_session(access_token) is False + + +def test_clear_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert validate_session(access_token) is True + + clear_session(access_token) + + assert validate_session(access_token) is None