diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index ac092826..b6cb07e1 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -4,7 +4,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute from facefusion.apis.choices import get_choices -from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset +from facefusion.apis.endpoints.assets import delete_asset, delete_assets, get_asset, get_assets, upload_asset from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state @@ -29,6 +29,7 @@ def create_api() -> Starlette: Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]), Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), + Route('/assets/{asset_id}', delete_asset, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]), Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]), Route('/choices', get_choices, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/remote', remote, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]), diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py index 73af75e6..9043dd35 100644 --- a/facefusion/apis/endpoints/assets.py +++ b/facefusion/apis/endpoints/assets.py @@ -1,5 +1,5 @@ import tempfile -from typing import List +from typing import Any, Dict, List, Optional from starlette.datastructures import UploadFile from starlette.requests import Request @@ -11,6 +11,22 @@ from facefusion.apis import asset_store from facefusion.apis.asset_helper import detect_media_type from facefusion.apis.endpoints.session import extract_access_token from facefusion.filesystem import get_file_extension, remove_file +from facefusion.types import AudioAsset, ImageAsset, VideoAsset + + +def translate_asset(asset : AudioAsset | ImageAsset | VideoAsset) -> Optional[Dict[str, Any]]: + return\ + { + 'id': asset.get('id'), + 'created_at': asset.get('created_at').isoformat(), + 'expires_at': asset.get('expires_at').isoformat(), + 'type': asset.get('type'), + 'media_type': asset.get('media'), + 'filename': asset.get('name'), + 'format': asset.get('format'), + 'size': asset.get('size'), + 'metadata': asset.get('metadata') + } async def upload_asset(request : Request) -> Response: @@ -34,6 +50,12 @@ async def upload_asset(request : Request) -> Response: asset_ids.append(asset_id) if asset_ids: + if asset_type == 'target': + return JSONResponse( + { + 'asset_id': asset_ids[0] + }, status_code = HTTP_201_CREATED) + return JSONResponse( { 'asset_ids': asset_ids @@ -68,6 +90,7 @@ async def save_asset_files(upload_files : List[UploadFile]) -> List[str]: async def get_assets(request : Request) -> Response: access_token = extract_access_token(request.scope) session_id = session_manager.find_session_id(access_token) + asset_type = request.query_params.get('type') if session_id: asset_set = asset_store.get_assets(session_id) @@ -75,22 +98,13 @@ async def get_assets(request : Request) -> Response: if asset_set: for asset in asset_set.values(): - assets.append( - { - 'id': asset.get('id'), - 'created_at': asset.get('created_at').isoformat(), - 'expires_at': asset.get('expires_at').isoformat(), - 'type': asset.get('type'), - 'media': asset.get('media'), - 'name': asset.get('name'), - 'format': asset.get('format'), - 'size': asset.get('size'), - 'metadata': asset.get('metadata') - }) + if not asset_type or asset.get('type') == asset_type: + assets.append(translate_asset(asset)) return JSONResponse( { - 'assets': assets + 'assets': assets, + 'count': len(assets) }, status_code = HTTP_200_OK) return Response(status_code = HTTP_400_BAD_REQUEST) @@ -107,23 +121,25 @@ async def get_asset(request : Request) -> Response: if asset: if action == 'download': - asset_path = asset.get('path') - asset_name = asset.get('name') + return FileResponse(asset.get('path'), filename = asset.get('name')) - return FileResponse(asset_path, filename = asset_name) + return JSONResponse(translate_asset(asset), status_code = HTTP_200_OK) - return JSONResponse( - { - 'id': asset.get('id'), - 'created_at': asset.get('created_at').isoformat(), - 'expires_at': asset.get('expires_at').isoformat(), - 'type': asset.get('type'), - 'media': asset.get('media'), - 'name': asset.get('name'), - 'format': asset.get('format'), - 'size': asset.get('size'), - 'metadata': asset.get('metadata') - }, status_code = HTTP_200_OK) + return Response(status_code = HTTP_404_NOT_FOUND) + + +async def delete_asset(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + asset_id = request.path_params.get('asset_id') + + if session_id and asset_id: + asset_set = asset_store.get_assets(session_id) + + if asset_set and asset_id in asset_set: + remove_file(asset_set.get(asset_id).get('path')) + asset_store.delete_assets(session_id, [ asset_id ]) + return Response(status_code = HTTP_200_OK) return Response(status_code = HTTP_404_NOT_FOUND)