diff --git a/facefusion/apis/endpoints/state.py b/facefusion/apis/endpoints/state.py index b8414475..322b16b6 100644 --- a/facefusion/apis/endpoints/state.py +++ b/facefusion/apis/endpoints/state.py @@ -1,6 +1,6 @@ from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND +from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_CONTENT from facefusion import args_helper, capability_store, session_manager, state_manager, translator from facefusion.apis import asset_store @@ -13,6 +13,8 @@ async def get_state(request : Request) -> JSONResponse: async def set_state(request : Request) -> JSONResponse: + __api_args__ = {} + action = request.query_params.get('action') asset_type = request.query_params.get('type') @@ -26,11 +28,19 @@ async def set_state(request : Request) -> JSONResponse: api_args = capability_store.get_api_arguments() for key, value in body.items(): - if key in api_args: - state_manager.set_item(key, value) + if key not in api_args: + return JSONResponse( + { + 'message': translator.get('invalid_state_key', 'facefusion.apis') + }, status_code = HTTP_400_BAD_REQUEST) + __api_args__[key] = value + state_manager.set_item(key, value) - __api_args__ = args_helper.extract_api_args(state_manager.get_state()) - return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) + if __api_args__: + __api_args__ = args_helper.extract_api_args(state_manager.get_state()) + return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) + + return JSONResponse({}, status_code = HTTP_422_UNPROCESSABLE_CONTENT) async def select_source(request : Request) -> JSONResponse: diff --git a/facefusion/apis/locales.py b/facefusion/apis/locales.py index 1f1c6ceb..831be2f4 100644 --- a/facefusion/apis/locales.py +++ b/facefusion/apis/locales.py @@ -9,6 +9,7 @@ LOCALES : Locales =\ 'invalid_access_token': 'invalid access token', 'invalid_refresh_token': 'invalid refresh token', 'source_asset_not_found': 'source asset not found', - 'target_asset_not_found': 'target asset not found' + 'target_asset_not_found': 'target asset not found', + 'invalid_state_key': 'invalid state key' } } diff --git a/tests/test_api_state.py b/tests/test_api_state.py index 3b907036..4210ba1e 100644 --- a/tests/test_api_state.py +++ b/tests/test_api_state.py @@ -120,7 +120,27 @@ def test_set_state(test_client : TestClient) -> None: set_state_body = set_state_response.json() assert set_state_body.get('invalid') is None - assert set_state_response.status_code == 200 + assert set_state_response.status_code == 400 + + set_state_response = test_client.put('/state', json = + { + 'execution_providers': [ 'cuda' ], + 'invalid': 'invalid' + }, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + set_state_body = set_state_response.json() + + assert set_state_body.get('invalid') is None + assert set_state_response.status_code == 400 + + set_state_response = test_client.put('/state', json = {}, headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert set_state_response.status_code == 422 def test_select_source_assets(test_client : TestClient) -> None: