diff --git a/facefusion/apis/endpoints/__init__.py b/facefusion/apis/endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/facefusion/apis/endpoints/state.py b/facefusion/apis/endpoints/state.py index 383c4ef2..797689c4 100644 --- a/facefusion/apis/endpoints/state.py +++ b/facefusion/apis/endpoints/state.py @@ -2,13 +2,13 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND -from facefusion import capability_store, session_manager, state_manager, translator +from facefusion import args_helper, capability_store, session_manager, state_manager, translator from facefusion.apis import asset_store from facefusion.apis.endpoints.session import extract_access_token async def get_state(request : Request) -> JSONResponse: - api_args = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] + api_args = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type] return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK) @@ -29,8 +29,8 @@ async def set_state(request : Request) -> JSONResponse: if key in api_args: state_manager.set_item(key, value) - __api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] - return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) #type:ignore[arg-type] + __api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type] + return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) async def select_source(request : Request) -> JSONResponse: @@ -50,7 +50,7 @@ async def select_source(request : Request) -> JSONResponse: state_manager.set_item('source_paths', source_paths) - __api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] + __api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type] return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) return JSONResponse( @@ -71,7 +71,7 @@ async def select_target(request : Request) -> JSONResponse: if asset: state_manager.set_item('target_path', asset.get('path')) - __api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] + __api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type] return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) return JSONResponse( diff --git a/facefusion/args_helper.py b/facefusion/args_helper.py index 83ef2287..ff0a3988 100644 --- a/facefusion/args_helper.py +++ b/facefusion/args_helper.py @@ -1,7 +1,10 @@ +from typing import Union + from facefusion.capability_store import get_api_arguments, get_cli_arguments, get_sys_arguments from facefusion.filesystem import get_file_name, is_video, resolve_file_paths from facefusion.normalizer import normalize_fps, normalize_space from facefusion.processors.core import get_processors_modules +from facefusion.processors.types import ProcessorState from facefusion.types import ApplyStateItem, Args, State from facefusion.vision import detect_video_fps @@ -82,7 +85,7 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: apply_state_item('step_index', args.get('step_index')) -def extract_api_args(state : State) -> Args: +def extract_api_args(state : Union[State, ProcessorState]) -> Args: api_args =\ { key: state.get(key) for key in state if key in get_api_arguments() @@ -90,7 +93,7 @@ def extract_api_args(state : State) -> Args: return api_args -def extract_cli_args(state : State) -> Args: +def extract_cli_args(state : Union[State, ProcessorState]) -> Args: cli_args =\ { key: state.get(key) for key in state if key in get_cli_arguments() @@ -98,7 +101,7 @@ def extract_cli_args(state : State) -> Args: return cli_args -def extract_sys_args(state : State) -> Args: +def extract_sys_args(state : Union[State, ProcessorState]) -> Args: sys_args =\ { key: state.get(key) for key in state if key in get_sys_arguments() @@ -106,9 +109,17 @@ def extract_sys_args(state : State) -> Args: return sys_args -def extract_step_args(state : State) -> Args: +def extract_step_args(state : Union[State, ProcessorState]) -> Args: step_args =\ { key: state.get(key) for key in state if key in get_cli_arguments() and key not in get_sys_arguments() } return step_args + + +def filter_step_args(args : Args) -> Args: + step_args =\ + { + key: args.get(key) for key in args if key in get_cli_arguments() and key not in get_sys_arguments() + } + return step_args diff --git a/facefusion/capability_store.py b/facefusion/capability_store.py index 2d6b4b2a..83ef0dc2 100644 --- a/facefusion/capability_store.py +++ b/facefusion/capability_store.py @@ -1,7 +1,7 @@ from argparse import Action from typing import Dict, List -from facefusion.types import Args, CapabilitySet, CapabilityStore, Scope, State +from facefusion.types import CapabilitySet, CapabilityStore, Scope CAPABILITY_STORE : CapabilityStore =\ { @@ -52,35 +52,3 @@ def register_capability_set(actions : List[Action], scopes : List[Scope]) -> Non CAPABILITY_STORE['cli'][action.dest] = value if scope == 'sys': CAPABILITY_STORE['sys'][action.dest] = value - - -def filter_api_args(state : State) -> Args: - api_args =\ - { - key: state.get(key) for key in state if key in get_api_arguments() - } - return api_args - - -def filter_cli_args(state : State) -> Args: - cli_args =\ - { - key: state.get(key) for key in state if key in get_cli_arguments() - } - return cli_args - - -def filter_step_args(args : Args) -> Args: - step_args =\ - { - key: args.get(key) for key in args if key in get_cli_arguments() and key not in get_sys_arguments() - } - return step_args - - -def filter_sys_args(state : State) -> Args: - sys_args =\ - { - key: state.get(key) for key in state if key in get_sys_arguments() - } - return sys_args diff --git a/facefusion/core.py b/facefusion/core.py index 8692d511..54aaee27 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -198,7 +198,7 @@ def route_job_manager(args : Args) -> ErrorCode: return 1 if state_manager.get_item('command') == 'job-add-step': - step_args = args_helper.extract_step_args(args) # type:ignore[arg-type] + step_args = args_helper.filter_step_args(args) if job_manager.add_step(state_manager.get_item('job_id'), step_args): logger.info(translator.get('job_step_added').format(job_id = state_manager.get_item('job_id')), __name__) @@ -207,7 +207,7 @@ def route_job_manager(args : Args) -> ErrorCode: return 1 if state_manager.get_item('command') == 'job-remix-step': - step_args = args_helper.extract_step_args(args) # type:ignore[arg-type] + step_args = args_helper.filter_step_args(args) if job_manager.remix_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args): logger.info(translator.get('job_remix_step_added').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) @@ -216,7 +216,7 @@ def route_job_manager(args : Args) -> ErrorCode: return 1 if state_manager.get_item('command') == 'job-insert-step': - step_args = args_helper.extract_step_args(args) # type:ignore[arg-type] + step_args = args_helper.filter_step_args(args) if job_manager.insert_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args): logger.info(translator.get('job_step_inserted').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) @@ -270,7 +270,7 @@ def route_job_runner() -> ErrorCode: def process_headless(args : Args) -> ErrorCode: job_id = job_helper.suggest_job_id('headless') - step_args = args_helper.extract_step_args(args) # type:ignore[arg-type] + step_args = args_helper.filter_step_args(args) if job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step): return 0 @@ -279,7 +279,7 @@ def process_headless(args : Args) -> ErrorCode: def process_batch(args : Args) -> ErrorCode: job_id = job_helper.suggest_job_id('batch') - step_args = args_helper.extract_step_args(args) # type:ignore[arg-type] + step_args = args_helper.filter_step_args(args) source_paths = resolve_file_pattern(step_args.get('source_pattern')) target_paths = resolve_file_pattern(step_args.get('target_pattern'))