diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 829cc9be..8d1f9d86 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,7 @@ jobs: python-version: '3.12' - run: python install.py --onnxruntime default --skip-conda - run: pip install pytest + - run: pip install httpx - run: pytest report: needs: test @@ -52,6 +53,7 @@ jobs: - run: pip install coveralls - run: pip install pytest - run: pip install pytest-cov + - run: pip install httpx - run: pytest tests --cov facefusion - run: coveralls --service github env: diff --git a/README.md b/README.md index 2942ff05..67503790 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ commands: batch-run run the program in batch mode force-download force automate downloads and exit benchmark benchmark the program + api start the API server job-list list jobs by status job-create create a drafted job job-submit submit a drafted job to become a queued job diff --git a/facefusion.ini b/facefusion.ini index 80fb530c..cccf3529 100644 --- a/facefusion.ini +++ b/facefusion.ini @@ -113,6 +113,10 @@ benchmark_mode = benchmark_resolutions = benchmark_cycle_count = +[api] +api_host = +api_port = + [execution] execution_device_ids = execution_providers = diff --git a/facefusion/apis/__init__.py b/facefusion/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py new file mode 100644 index 00000000..67a08502 --- /dev/null +++ b/facefusion/apis/core.py @@ -0,0 +1,30 @@ +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Route + +from facefusion.apis.session import create_session +from facefusion.apis.session import create_session_guard +from facefusion.apis.session import destroy_session +from facefusion.apis.session import get_session +from facefusion.apis.session import refresh_session +from facefusion.apis.state import get_state +from facefusion.apis.state import set_state + + +def create_api() -> Starlette: + session_guard = Middleware(create_session_guard) + routes =\ + [ + Route('/session', create_session, methods = [ 'POST' ]), + Route('/session', get_session, methods = [ 'GET' ], middleware = [ session_guard ]), + Route('/session', refresh_session, methods = [ 'PUT' ]), + Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]), + Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]), + Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]) + ] + + api = Starlette(routes = routes) + api.add_middleware(CORSMiddleware, allow_origins = [ '*' ], allow_methods = [ '*' ], allow_headers = [ '*' ]) + + return api diff --git a/facefusion/apis/locals.py b/facefusion/apis/locals.py new file mode 100644 index 00000000..97be9e00 --- /dev/null +++ b/facefusion/apis/locals.py @@ -0,0 +1,12 @@ +from facefusion.types import Locals + +LOCALS : Locals =\ +{ + 'en': + { + 'ok': 'ok', + 'something_went_wrong': 'something went wrong', + 'invalid_access_token': 'invalid access token', + 'invalid_refresh_token': 'invalid refresh token' + } +} diff --git a/facefusion/apis/session.py b/facefusion/apis/session.py new file mode 100644 index 00000000..d18946dc --- /dev/null +++ b/facefusion/apis/session.py @@ -0,0 +1,131 @@ +import os +from typing import Optional + +from starlette.datastructures import Headers +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, HTTP_426_UPGRADE_REQUIRED +from starlette.types import ASGIApp, Receive, Scope, Send + +from facefusion import session_manager, translator +from facefusion.types import Token + + +async def create_session(request : Request) -> JSONResponse: + body = await request.json() + + if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'): + session = session_manager.create_session() + session_manager.set_session(session.get('access_token'), session) + + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token') + }, status_code = HTTP_201_CREATED) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def get_session(request : Request) -> JSONResponse: + access_token = extract_access_token(request.headers) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at').isoformat(), + 'expires_at': session.get('expires_at').isoformat() + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def refresh_session(request : Request) -> JSONResponse: + body = await request.json() + + for access_token, session in session_manager.SESSIONS.items(): + if session.get('refresh_token') == body.get('refresh_token'): + session_manager.clear_session(access_token) + session = session_manager.create_session() + session_manager.set_session(session.get('access_token'), session) + + return JSONResponse( + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token') + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +async def destroy_session(request : Request) -> JSONResponse: + access_token = extract_access_token(request.headers) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + session_manager.clear_session(access_token) + return JSONResponse( + { + 'message': translator.get('ok', __package__) + }, status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('something_went_wrong', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + +def create_session_guard(app : ASGIApp) -> ASGIApp: + async def middleware(scope : Scope, receive : Receive, send : Send) -> None: + access_token = extract_access_token(Headers(scope = scope)) + + if access_token and session_manager.validate_session(access_token): + return await app(scope, receive, send) + + if access_token: + session = session_manager.get_session(access_token) + + if session: + response = JSONResponse( + { + 'message': translator.get('invalid_access_token', __package__) + }, status_code = HTTP_426_UPGRADE_REQUIRED) + + return await response(scope, receive, send) + + response = JSONResponse( + { + 'message': translator.get('invalid_access_token', __package__) + }, status_code = HTTP_401_UNAUTHORIZED) + + return await response(scope, receive, send) + + return middleware + + +def extract_access_token(headers : Headers) -> Optional[Token]: + auth_header = headers.get('Authorization') + + if auth_header: + auth_prefix, _, access_token = auth_header.partition(' ') + + if auth_prefix.lower() == 'bearer' and access_token: + return access_token + + return None diff --git a/facefusion/apis/state.py b/facefusion/apis/state.py new file mode 100644 index 00000000..682240c9 --- /dev/null +++ b/facefusion/apis/state.py @@ -0,0 +1,23 @@ +from typing import get_args + +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_200_OK + +from facefusion import state_manager +from facefusion.types import StateKey + + +async def get_state(request : Request) -> JSONResponse: + return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK) + + +async def set_state(request : Request) -> JSONResponse: + body = await request.json() + + for key, value in body.items(): + if key in get_args(StateKey): + state_manager.set_item(key, value) + + return JSONResponse(state_manager.get_state(), status_code = HTTP_200_OK) + diff --git a/facefusion/args.py b/facefusion/args.py index 6d64ecac..095995ee 100644 --- a/facefusion/args.py +++ b/facefusion/args.py @@ -7,81 +7,6 @@ from facefusion.types import ApplyStateItem, Args from facefusion.vision import detect_video_fps -def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: - apply_state_item('command', args.get('command')) - apply_state_item('temp_path', args.get('temp_path')) - apply_state_item('jobs_path', args.get('jobs_path')) - apply_state_item('source_paths', args.get('source_paths')) - apply_state_item('target_path', args.get('target_path')) - apply_state_item('output_path', args.get('output_path')) - apply_state_item('source_pattern', args.get('source_pattern')) - apply_state_item('target_pattern', args.get('target_pattern')) - apply_state_item('output_pattern', args.get('output_pattern')) - apply_state_item('face_detector_model', args.get('face_detector_model')) - apply_state_item('face_detector_size', args.get('face_detector_size')) - apply_state_item('face_detector_margin', normalize_space(args.get('face_detector_margin'))) - apply_state_item('face_detector_angles', args.get('face_detector_angles')) - apply_state_item('face_detector_score', args.get('face_detector_score')) - apply_state_item('face_landmarker_model', args.get('face_landmarker_model')) - apply_state_item('face_landmarker_score', args.get('face_landmarker_score')) - apply_state_item('face_selector_mode', args.get('face_selector_mode')) - apply_state_item('face_selector_order', args.get('face_selector_order')) - apply_state_item('face_selector_age_start', args.get('face_selector_age_start')) - apply_state_item('face_selector_age_end', args.get('face_selector_age_end')) - apply_state_item('face_selector_gender', args.get('face_selector_gender')) - apply_state_item('face_selector_race', args.get('face_selector_race')) - apply_state_item('reference_face_position', args.get('reference_face_position')) - apply_state_item('reference_face_distance', args.get('reference_face_distance')) - apply_state_item('reference_frame_number', args.get('reference_frame_number')) - apply_state_item('face_occluder_model', args.get('face_occluder_model')) - apply_state_item('face_parser_model', args.get('face_parser_model')) - apply_state_item('face_mask_types', args.get('face_mask_types')) - apply_state_item('face_mask_areas', args.get('face_mask_areas')) - apply_state_item('face_mask_regions', args.get('face_mask_regions')) - apply_state_item('face_mask_blur', args.get('face_mask_blur')) - apply_state_item('face_mask_padding', normalize_space(args.get('face_mask_padding'))) - apply_state_item('voice_extractor_model', args.get('voice_extractor_model')) - apply_state_item('trim_frame_start', args.get('trim_frame_start')) - apply_state_item('trim_frame_end', args.get('trim_frame_end')) - apply_state_item('temp_frame_format', args.get('temp_frame_format')) - apply_state_item('keep_temp', args.get('keep_temp')) - apply_state_item('output_image_quality', args.get('output_image_quality')) - apply_state_item('output_image_scale', args.get('output_image_scale')) - apply_state_item('output_audio_encoder', args.get('output_audio_encoder')) - apply_state_item('output_audio_quality', args.get('output_audio_quality')) - apply_state_item('output_audio_volume', args.get('output_audio_volume')) - apply_state_item('output_video_encoder', args.get('output_video_encoder')) - apply_state_item('output_video_preset', args.get('output_video_preset')) - apply_state_item('output_video_quality', args.get('output_video_quality')) - apply_state_item('output_video_scale', args.get('output_video_scale')) - - if args.get('output_video_fps') or is_video(args.get('target_path')): - output_video_fps = normalize_fps(args.get('output_video_fps')) or detect_video_fps(args.get('target_path')) - apply_state_item('output_video_fps', output_video_fps) - - available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] - apply_state_item('processors', args.get('processors')) - - for processor_module in get_processors_modules(available_processors): - processor_module.apply_args(args, apply_state_item) - # execution - apply_state_item('execution_device_ids', args.get('execution_device_ids')) - apply_state_item('execution_providers', args.get('execution_providers')) - apply_state_item('execution_thread_count', args.get('execution_thread_count')) - apply_state_item('download_providers', args.get('download_providers')) - apply_state_item('download_scope', args.get('download_scope')) - apply_state_item('benchmark_mode', args.get('benchmark_mode')) - apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions')) - apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count')) - apply_state_item('video_memory_strategy', args.get('video_memory_strategy')) - apply_state_item('system_memory_limit', args.get('system_memory_limit')) - apply_state_item('log_level', args.get('log_level')) - apply_state_item('halt_on_error', args.get('halt_on_error')) - apply_state_item('job_id', args.get('job_id')) - apply_state_item('job_status', args.get('job_status')) - apply_state_item('step_index', args.get('step_index')) - - def reduce_step_args(args : Args) -> Args: step_args =\ { @@ -112,3 +37,94 @@ def collect_job_args() -> Args: key: state_manager.get_item(key) for key in job_store.get_job_keys() #type:ignore[arg-type] } return job_args + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + # general + apply_state_item('command', args.get('command')) + # paths + apply_state_item('temp_path', args.get('temp_path')) + apply_state_item('jobs_path', args.get('jobs_path')) + apply_state_item('source_paths', args.get('source_paths')) + apply_state_item('target_path', args.get('target_path')) + apply_state_item('output_path', args.get('output_path')) + # patterns + apply_state_item('source_pattern', args.get('source_pattern')) + apply_state_item('target_pattern', args.get('target_pattern')) + apply_state_item('output_pattern', args.get('output_pattern')) + # face detector + apply_state_item('face_detector_model', args.get('face_detector_model')) + apply_state_item('face_detector_size', args.get('face_detector_size')) + apply_state_item('face_detector_margin', normalize_space(args.get('face_detector_margin'))) + apply_state_item('face_detector_angles', args.get('face_detector_angles')) + apply_state_item('face_detector_score', args.get('face_detector_score')) + # face landmarker + apply_state_item('face_landmarker_model', args.get('face_landmarker_model')) + apply_state_item('face_landmarker_score', args.get('face_landmarker_score')) + # face selector + apply_state_item('face_selector_mode', args.get('face_selector_mode')) + apply_state_item('face_selector_order', args.get('face_selector_order')) + apply_state_item('face_selector_age_start', args.get('face_selector_age_start')) + apply_state_item('face_selector_age_end', args.get('face_selector_age_end')) + apply_state_item('face_selector_gender', args.get('face_selector_gender')) + apply_state_item('face_selector_race', args.get('face_selector_race')) + apply_state_item('reference_face_position', args.get('reference_face_position')) + apply_state_item('reference_face_distance', args.get('reference_face_distance')) + apply_state_item('reference_frame_number', args.get('reference_frame_number')) + # face masker + apply_state_item('face_occluder_model', args.get('face_occluder_model')) + apply_state_item('face_parser_model', args.get('face_parser_model')) + apply_state_item('face_mask_types', args.get('face_mask_types')) + apply_state_item('face_mask_areas', args.get('face_mask_areas')) + apply_state_item('face_mask_regions', args.get('face_mask_regions')) + apply_state_item('face_mask_blur', args.get('face_mask_blur')) + apply_state_item('face_mask_padding', normalize_space(args.get('face_mask_padding'))) + # voice extractor + apply_state_item('voice_extractor_model', args.get('voice_extractor_model')) + # frame extraction + apply_state_item('trim_frame_start', args.get('trim_frame_start')) + apply_state_item('trim_frame_end', args.get('trim_frame_end')) + apply_state_item('temp_frame_format', args.get('temp_frame_format')) + apply_state_item('keep_temp', args.get('keep_temp')) + # output creation + apply_state_item('output_image_quality', args.get('output_image_quality')) + apply_state_item('output_image_scale', args.get('output_image_scale')) + apply_state_item('output_audio_encoder', args.get('output_audio_encoder')) + apply_state_item('output_audio_quality', args.get('output_audio_quality')) + apply_state_item('output_audio_volume', args.get('output_audio_volume')) + apply_state_item('output_video_encoder', args.get('output_video_encoder')) + apply_state_item('output_video_preset', args.get('output_video_preset')) + apply_state_item('output_video_quality', args.get('output_video_quality')) + apply_state_item('output_video_scale', args.get('output_video_scale')) + if args.get('output_video_fps') or is_video(args.get('target_path')): + output_video_fps = normalize_fps(args.get('output_video_fps')) or detect_video_fps(args.get('target_path')) + apply_state_item('output_video_fps', output_video_fps) + # processors + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + apply_state_item('processors', args.get('processors')) + for processor_module in get_processors_modules(available_processors): + processor_module.apply_args(args, apply_state_item) + # execution + apply_state_item('execution_device_ids', args.get('execution_device_ids')) + apply_state_item('execution_providers', args.get('execution_providers')) + apply_state_item('execution_thread_count', args.get('execution_thread_count')) + # download + apply_state_item('download_providers', args.get('download_providers')) + apply_state_item('download_scope', args.get('download_scope')) + # benchmark + apply_state_item('benchmark_mode', args.get('benchmark_mode')) + apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions')) + apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count')) + # api + apply_state_item('api_host', args.get('api_host')) + apply_state_item('api_port', args.get('api_port')) + # memory + apply_state_item('video_memory_strategy', args.get('video_memory_strategy')) + apply_state_item('system_memory_limit', args.get('system_memory_limit')) + # misc + apply_state_item('log_level', args.get('log_level')) + apply_state_item('halt_on_error', args.get('halt_on_error')) + # jobs + apply_state_item('job_id', args.get('job_id')) + apply_state_item('job_status', args.get('job_status')) + apply_state_item('step_index', args.get('step_index')) diff --git a/facefusion/core.py b/facefusion/core.py index 228768cd..a8b5a127 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -5,7 +5,10 @@ import signal import sys from time import time +import uvicorn + from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, state_manager, translator, voice_extractor +from facefusion.apis.core import create_api from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args from facefusion.download import conditional_download_hashes, conditional_download_sources from facefusion.exit_helper import hard_exit, signal_exit @@ -55,6 +58,11 @@ def route(args : Args) -> None: hard_exit(2) benchmarker.render() + if state_manager.get_item('command') == 'api': + logger.info(translator.get('api_started').format(host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')), __name__) + uvicorn.run(create_api(), host = state_manager.get_item('api_host'), port = state_manager.get_item('api_port')) + hard_exit(1) + if state_manager.get_item('command') in [ 'job-list', 'job-create', 'job-submit', 'job-submit-all', 'job-delete', 'job-delete-all', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ]: if not job_manager.init_jobs(state_manager.get_item('jobs_path')): hard_exit(1) diff --git a/facefusion/locales.py b/facefusion/locales.py index a6c97713..659c8abf 100644 --- a/facefusion/locales.py +++ b/facefusion/locales.py @@ -48,10 +48,9 @@ LOCALES : Locales =\ 'no_source_face_detected': 'no source face detected', 'processor_not_loaded': 'processor {processor} could not be loaded', 'processor_not_implemented': 'processor {processor} not implemented correctly', - 'ui_layout_not_loaded': 'ui layout {ui_layout} could not be loaded', - 'ui_layout_not_implemented': 'ui layout {ui_layout} not implemented correctly', 'stream_not_loaded': 'stream {stream_mode} could not be loaded', 'stream_not_supported': 'stream not supported', + 'api_started': 'started API on {host}:{port}', 'job_created': 'job {job_id} created', 'job_not_created': 'job {job_id} not created', 'job_submitted': 'job {job_id} submitted', @@ -154,6 +153,8 @@ LOCALES : Locales =\ 'benchmark_mode': 'choose the benchmark mode', 'benchmark_resolutions': 'choose the resolutions for the benchmarks (choices: {choices}, ...)', 'benchmark_cycle_count': 'specify the amount of cycles per benchmark', + 'api_host': 'specify the API host', + 'api_port': 'specify the API port', 'execution_device_ids': 'specify the devices used for processing', 'execution_providers': 'inference using different providers (choices: {choices}, ...)', 'execution_thread_count': 'specify the amount of parallel threads while processing', @@ -165,6 +166,7 @@ LOCALES : Locales =\ 'batch_run': 'run the program in batch mode', 'force_download': 'force automate downloads and exit', 'benchmark': 'benchmark the program', + 'api': 'start the API server', 'job_id': 'specify the job id', 'job_status': 'specify the job status', 'step_index': 'specify the step index', diff --git a/facefusion/program.py b/facefusion/program.py index 06b1c644..cec30c1c 100755 --- a/facefusion/program.py +++ b/facefusion/program.py @@ -220,6 +220,14 @@ def create_benchmark_program() -> ArgumentParser: return program +def create_api_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_api = program.add_argument_group('api') + group_api.add_argument('--api-host', help = translator.get('help.api_host'), default = config.get_str_value('api', 'api_host', '127.0.0.1')) + group_api.add_argument('--api-port', help = translator.get('help.api_port'), type = int, default = config.get_int_value('api', 'api_port', '8000')) + return program + + def create_execution_program() -> ArgumentParser: program = ArgumentParser(add_help = False) available_execution_providers = get_available_execution_providers() @@ -291,6 +299,7 @@ def create_program() -> ArgumentParser: sub_program.add_parser('batch-run', help = translator.get('help.batch_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_pattern_program(), create_target_pattern_program(), create_output_pattern_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('force-download', help = translator.get('help.force_download'), parents = [ create_download_providers_program(), create_download_scope_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('benchmark', help = translator.get('help.benchmark'), parents = [ create_temp_path_program(), collect_step_program(), create_benchmark_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('api', help = translator.get('help.api'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_api_program(), collect_job_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('job-list', help = translator.get('help.job_list'), parents = [ create_job_status_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('job-create', help = translator.get('help.job_create'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) sub_program.add_parser('job-submit', help = translator.get('help.job_submit'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) diff --git a/facefusion/session_manager.py b/facefusion/session_manager.py new file mode 100644 index 00000000..8a27abca --- /dev/null +++ b/facefusion/session_manager.py @@ -0,0 +1,39 @@ +import secrets +from datetime import datetime, timedelta +from typing import Dict +from typing import Optional + +from facefusion.types import Session, Token + +SESSIONS : Dict[Token, Session] = {} + + +def create_session() -> Session: + session : Session =\ + { + 'access_token': secrets.token_urlsafe(128), + 'refresh_token': secrets.token_urlsafe(128), + 'created_at': datetime.now(), + 'expires_at': datetime.now() + timedelta(minutes = 10) + } + + return session + + +def get_session(access_token : Token) -> Optional[Session]: + return SESSIONS.get(access_token) + + +def set_session(access_token : Token, session : Session) -> None: + SESSIONS[access_token] = session + + +def validate_session(access_token : Token) -> bool: + session = get_session(access_token) + return session and datetime.now() <= session.get('expires_at') + + +def clear_session(access_token : Token) -> None: + if access_token in SESSIONS: + del SESSIONS[access_token] + diff --git a/facefusion/types.py b/facefusion/types.py index 9d393a13..d55be4ff 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -1,4 +1,5 @@ from collections import namedtuple +from datetime import datetime from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeAlias, TypedDict import cv2 @@ -101,6 +102,15 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool] Content : TypeAlias = Dict[str, Any] +Token : TypeAlias = str +Session = TypedDict('Session', +{ + 'access_token': Token, + 'refresh_token': Token, + 'created_at': datetime, + 'expires_at': datetime +}) + Command : TypeAlias = str CommandSet : TypeAlias = Dict[str, List[Command]] diff --git a/requirements.txt b/requirements.txt index 957cebc2..26369c79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ opencv-python==4.13.0.92 psutil==7.1.3 tqdm==4.67.3 scipy==1.17.1 +starlette==0.50.0 +uvicorn==0.34.0 diff --git a/tests/test_api_session.py b/tests/test_api_session.py new file mode 100644 index 00000000..29f83eb9 --- /dev/null +++ b/tests/test_api_session.py @@ -0,0 +1,154 @@ +import os +from datetime import timedelta +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager +from facefusion.apis.core import create_api +from facefusion.types import Session + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + + +def test_create_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + assert session_manager.get_session(create_session_body.get('access_token')) + assert create_session_response.status_code == 201 + + create_session_response = test_client.post('/session', json = + { + 'api_key': 'TEST', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 401 + + os.environ['FACEFUSION_API_KEY'] = 'TEST' + create_session_response = test_client.post('/session', json = + { + 'api_key': 'INVALID', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 401 + + os.environ['FACEFUSION_API_KEY'] = 'TEST' + create_session_response = test_client.post('/session', json = + { + 'api_key': 'TEST', + 'client_version': metadata.get('version') + }) + + assert create_session_response.status_code == 201 + + del os.environ['FACEFUSION_API_KEY'] + + +def test_get_session(test_client : TestClient) -> None: + get_session_response = test_client.get('/session') + + assert get_session_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + get_session_response = test_client.get('/session', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert get_session_response.status_code == 200 + + access_token = create_session_body.get('access_token') + session : Session = session_manager.get_session(access_token) + session_manager.set_session(access_token, + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at'), + 'expires_at': session.get('expires_at') - timedelta(hours = 1) + }) + + get_session_response = test_client.get('/session', headers = + { + 'Authorization': 'Bearer ' + access_token + }) + + assert get_session_response.status_code == 426 + + +def test_refresh_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': 'INVALID' + }) + + assert refresh_session_response.status_code == 401 + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': create_session_body.get('refresh_token') + }) + refresh_session_body = refresh_session_response.json() + + assert session_manager.get_session(create_session_body.get('access_token')) is None + + assert session_manager.get_session(refresh_session_body.get('access_token')) + + assert refresh_session_response.status_code == 200 + + refresh_session_response = test_client.put('/session', json = + { + 'refresh_token': create_session_body.get('refresh_token') + }) + + assert refresh_session_response.status_code == 401 + + +def test_destroy_session(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + delete_session_response = test_client.delete('/session', headers = + { + 'Authorization': 'Bearer INVALID' + }) + + assert delete_session_response.status_code == 401 + + delete_session_response = test_client.delete('/session', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert session_manager.get_session(create_session_body.get('access_token')) is None + + assert delete_session_response.status_code == 200 diff --git a/tests/test_api_state.py b/tests/test_api_state.py new file mode 100644 index 00000000..d15797c1 --- /dev/null +++ b/tests/test_api_state.py @@ -0,0 +1,80 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager, state_manager +from facefusion.apis.core import create_api + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + state_manager.init_item('execution_providers', [ 'cpu' ]) + + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + + +def test_get_state(test_client : TestClient) -> None: + get_state_response = test_client.get('/state') + + assert get_state_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + get_state_response = test_client.get('/state', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + get_state_body = get_state_response.json() + + assert get_state_body.get('execution_providers') == [ 'cpu' ] + assert get_state_response.status_code == 200 + + +def test_set_state(test_client : TestClient) -> None: + set_state_response = test_client.put('/state', json = + { + 'execution_providers': [ 'cuda' ] + }) + + assert set_state_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + set_state_response = test_client.put('/state', json = + { + 'execution_providers': [ 'cuda' ] + }, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + set_state_body = set_state_response.json() + + assert set_state_body.get('execution_providers') == [ 'cuda' ] + assert set_state_response.status_code == 200 + + set_state_response = test_client.put('/state', json = + { + 'invalid': 'invalid' + }, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + set_state_body = set_state_response.json() + + assert set_state_body.get('invalid') is None + assert set_state_response.status_code == 200 diff --git a/tests/test_job_manager.py b/tests/test_job_manager.py index 05b1fedc..27164146 100644 --- a/tests/test_job_manager.py +++ b/tests/test_job_manager.py @@ -106,7 +106,7 @@ def test_find_jobs() -> None: assert 'job-test-find-jobs-1' in find_jobs('drafted') assert 'job-test-find-jobs-2' in find_jobs('drafted') - assert not find_jobs('queued') + assert find_jobs('queued') == {} move_job_file('job-test-find-jobs-1', 'queued') diff --git a/tests/test_json.py b/tests/test_json.py index 6b7cfb1a..29d9b48b 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -8,7 +8,7 @@ def test_read_json() -> None: file_descriptor, json_path = tempfile.mkstemp(suffix = '.json') os.close(file_descriptor) - assert not read_json(json_path) + assert read_json(json_path) is None write_json(json_path, {}) diff --git a/tests/test_session_manager.py b/tests/test_session_manager.py new file mode 100644 index 00000000..c886c541 --- /dev/null +++ b/tests/test_session_manager.py @@ -0,0 +1,45 @@ +import secrets +from datetime import timedelta + +from facefusion.session_manager import clear_session, create_session, get_session, set_session, validate_session + + +def test_get_and_set_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert get_session(access_token) == session + + +def test_validate_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert validate_session(access_token) is True + + set_session(access_token, + { + 'access_token': session.get('access_token'), + 'refresh_token': session.get('refresh_token'), + 'created_at': session.get('created_at'), + 'expires_at': session.get('expires_at') - timedelta(hours = 1) + }) + + assert validate_session(access_token) is False + + +def test_clear_session() -> None: + session = create_session() + access_token = secrets.token_urlsafe(128) + + set_session(access_token, session) + + assert validate_session(access_token) is True + + clear_session(access_token) + + assert validate_session(access_token) is None