mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 21:07:50 +02:00
Better args types part2 (#1051)
* remove helper methods finally * state becomes the total truth now * state becomes the total truth now * state becomes the total truth now * state becomes the total truth now * add ini file
This commit is contained in:
@@ -2,13 +2,13 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import capability_store, session_manager, state_manager, translator
|
||||
from facefusion import args_helper, capability_store, session_manager, state_manager, translator
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
|
||||
|
||||
async def get_state(request : Request) -> JSONResponse:
|
||||
api_args = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
api_args = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ async def set_state(request : Request) -> JSONResponse:
|
||||
if key in api_args:
|
||||
state_manager.set_item(key, value)
|
||||
|
||||
__api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) #type:ignore[arg-type]
|
||||
__api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def select_source(request : Request) -> JSONResponse:
|
||||
@@ -50,7 +50,7 @@ async def select_source(request : Request) -> JSONResponse:
|
||||
|
||||
state_manager.set_item('source_paths', source_paths)
|
||||
|
||||
__api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
__api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -71,7 +71,7 @@ async def select_target(request : Request) -> JSONResponse:
|
||||
if asset:
|
||||
state_manager.set_item('target_path', asset.get('path'))
|
||||
|
||||
__api_args__ = capability_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
__api_args__ = args_helper.extract_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
from facefusion.capability_store import get_api_arguments, get_cli_arguments, get_sys_arguments
|
||||
from facefusion.filesystem import get_file_name, is_video, resolve_file_paths
|
||||
from facefusion.normalizer import normalize_fps, normalize_space
|
||||
from facefusion.processors.core import get_processors_modules
|
||||
from facefusion.processors.types import ProcessorState
|
||||
from facefusion.types import ApplyStateItem, Args, State
|
||||
from facefusion.vision import detect_video_fps
|
||||
|
||||
@@ -82,7 +85,7 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
|
||||
apply_state_item('step_index', args.get('step_index'))
|
||||
|
||||
|
||||
def extract_api_args(state : State) -> Args:
|
||||
def extract_api_args(state : Union[State, ProcessorState]) -> Args:
|
||||
api_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_api_arguments()
|
||||
@@ -90,7 +93,7 @@ def extract_api_args(state : State) -> Args:
|
||||
return api_args
|
||||
|
||||
|
||||
def extract_cli_args(state : State) -> Args:
|
||||
def extract_cli_args(state : Union[State, ProcessorState]) -> Args:
|
||||
cli_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_cli_arguments()
|
||||
@@ -98,7 +101,7 @@ def extract_cli_args(state : State) -> Args:
|
||||
return cli_args
|
||||
|
||||
|
||||
def extract_sys_args(state : State) -> Args:
|
||||
def extract_sys_args(state : Union[State, ProcessorState]) -> Args:
|
||||
sys_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_sys_arguments()
|
||||
@@ -106,9 +109,17 @@ def extract_sys_args(state : State) -> Args:
|
||||
return sys_args
|
||||
|
||||
|
||||
def extract_step_args(state : State) -> Args:
|
||||
def extract_step_args(state : Union[State, ProcessorState]) -> Args:
|
||||
step_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_cli_arguments() and key not in get_sys_arguments()
|
||||
}
|
||||
return step_args
|
||||
|
||||
|
||||
def filter_step_args(args : Args) -> Args:
|
||||
step_args =\
|
||||
{
|
||||
key: args.get(key) for key in args if key in get_cli_arguments() and key not in get_sys_arguments()
|
||||
}
|
||||
return step_args
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from argparse import Action
|
||||
from typing import Dict, List
|
||||
|
||||
from facefusion.types import Args, CapabilitySet, CapabilityStore, Scope, State
|
||||
from facefusion.types import CapabilitySet, CapabilityStore, Scope
|
||||
|
||||
CAPABILITY_STORE : CapabilityStore =\
|
||||
{
|
||||
@@ -52,35 +52,3 @@ def register_capability_set(actions : List[Action], scopes : List[Scope]) -> Non
|
||||
CAPABILITY_STORE['cli'][action.dest] = value
|
||||
if scope == 'sys':
|
||||
CAPABILITY_STORE['sys'][action.dest] = value
|
||||
|
||||
|
||||
def filter_api_args(state : State) -> Args:
|
||||
api_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_api_arguments()
|
||||
}
|
||||
return api_args
|
||||
|
||||
|
||||
def filter_cli_args(state : State) -> Args:
|
||||
cli_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_cli_arguments()
|
||||
}
|
||||
return cli_args
|
||||
|
||||
|
||||
def filter_step_args(args : Args) -> Args:
|
||||
step_args =\
|
||||
{
|
||||
key: args.get(key) for key in args if key in get_cli_arguments() and key not in get_sys_arguments()
|
||||
}
|
||||
return step_args
|
||||
|
||||
|
||||
def filter_sys_args(state : State) -> Args:
|
||||
sys_args =\
|
||||
{
|
||||
key: state.get(key) for key in state if key in get_sys_arguments()
|
||||
}
|
||||
return sys_args
|
||||
|
||||
+5
-5
@@ -198,7 +198,7 @@ def route_job_manager(args : Args) -> ErrorCode:
|
||||
return 1
|
||||
|
||||
if state_manager.get_item('command') == 'job-add-step':
|
||||
step_args = args_helper.extract_step_args(args) # type:ignore[arg-type]
|
||||
step_args = args_helper.filter_step_args(args)
|
||||
|
||||
if job_manager.add_step(state_manager.get_item('job_id'), step_args):
|
||||
logger.info(translator.get('job_step_added').format(job_id = state_manager.get_item('job_id')), __name__)
|
||||
@@ -207,7 +207,7 @@ def route_job_manager(args : Args) -> ErrorCode:
|
||||
return 1
|
||||
|
||||
if state_manager.get_item('command') == 'job-remix-step':
|
||||
step_args = args_helper.extract_step_args(args) # type:ignore[arg-type]
|
||||
step_args = args_helper.filter_step_args(args)
|
||||
|
||||
if job_manager.remix_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args):
|
||||
logger.info(translator.get('job_remix_step_added').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__)
|
||||
@@ -216,7 +216,7 @@ def route_job_manager(args : Args) -> ErrorCode:
|
||||
return 1
|
||||
|
||||
if state_manager.get_item('command') == 'job-insert-step':
|
||||
step_args = args_helper.extract_step_args(args) # type:ignore[arg-type]
|
||||
step_args = args_helper.filter_step_args(args)
|
||||
|
||||
if job_manager.insert_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args):
|
||||
logger.info(translator.get('job_step_inserted').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__)
|
||||
@@ -270,7 +270,7 @@ def route_job_runner() -> ErrorCode:
|
||||
|
||||
def process_headless(args : Args) -> ErrorCode:
|
||||
job_id = job_helper.suggest_job_id('headless')
|
||||
step_args = args_helper.extract_step_args(args) # type:ignore[arg-type]
|
||||
step_args = args_helper.filter_step_args(args)
|
||||
|
||||
if job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step):
|
||||
return 0
|
||||
@@ -279,7 +279,7 @@ def process_headless(args : Args) -> ErrorCode:
|
||||
|
||||
def process_batch(args : Args) -> ErrorCode:
|
||||
job_id = job_helper.suggest_job_id('batch')
|
||||
step_args = args_helper.extract_step_args(args) # type:ignore[arg-type]
|
||||
step_args = args_helper.filter_step_args(args)
|
||||
source_paths = resolve_file_pattern(step_args.get('source_pattern'))
|
||||
target_paths = resolve_file_pattern(step_args.get('target_pattern'))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user