diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py index 72d8b4bc..be20dbfd 100644 --- a/facefusion/apis/asset_store.py +++ b/facefusion/apis/asset_store.py @@ -1,11 +1,11 @@ import os import uuid from datetime import datetime, timedelta -from typing import Optional, cast +from typing import List, Optional, cast from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata from facefusion.filesystem import get_file_format, get_file_name -from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat +from facefusion.types import AssetId, AssetSet, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat ASSET_STORE : AssetStore = {} @@ -70,9 +70,21 @@ def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : st return ASSET_STORE[session_id].get(asset_id) +def get_assets(session_id : SessionId) -> Optional[AssetSet]: + return ASSET_STORE.get(session_id) + + def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]: if session_id in ASSET_STORE: - return ASSET_STORE[session_id].get(asset_id) + return ASSET_STORE.get(session_id).get(asset_id) + return None + + +def delete_assets(session_id : SessionId, asset_ids : List[AssetId]) -> None: + if session_id in ASSET_STORE: + for asset_id in asset_ids: + if asset_id in ASSET_STORE.get(session_id): + del ASSET_STORE[session_id][asset_id] return None diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 8692c4bb..9ecd5744 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -3,7 +3,7 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute -from facefusion.apis.endpoints.assets import upload_asset +from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state @@ -19,7 +19,10 @@ def create_api() -> Starlette: Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]), Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]), Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]), + Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ session_guard ]), Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]), + Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]), + Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]) ] diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py index 55687699..057dedf3 100644 --- a/facefusion/apis/endpoints/assets.py +++ b/facefusion/apis/endpoints/assets.py @@ -5,7 +5,7 @@ from typing import List from starlette.datastructures import UploadFile from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST +from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND from facefusion import ffmpeg, process_manager, session_manager, state_manager from facefusion.apis import asset_store @@ -22,7 +22,7 @@ async def upload_asset(request : Request) -> Response: if session_id and asset_type in [ 'source', 'target' ]: form = await request.form() upload_files = form.getlist('file') - asset_paths = await save_asset_files(upload_files) + asset_paths = await save_asset_files(upload_files) # type: ignore[arg-type] if asset_paths: asset_ids : List[str] = [] @@ -76,3 +76,78 @@ async def save_asset_files(upload_files : List[UploadFile]) -> List[str]: remove_file(temp_file.name) return asset_paths + + +async def get_assets(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + + if session_id: + asset_set = asset_store.get_assets(session_id) + assets = [] + + if asset_set: + for asset in asset_set.values(): + assets.append( + { + 'id': asset.get('id'), + 'created_at': asset.get('created_at').isoformat(), + 'expires_at': asset.get('expires_at').isoformat(), + 'type': asset.get('type'), + 'media': asset.get('media'), + 'name': asset.get('name'), + 'format': asset.get('format'), + 'size': asset.get('size'), + 'metadata': asset.get('metadata') + }) + + return JSONResponse( + { + 'assets': assets + }, status_code = HTTP_200_OK) + + return Response(status_code = HTTP_400_BAD_REQUEST) + + +async def get_asset(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + asset_id = request.path_params.get('asset_id') + + if session_id and asset_id: + asset = asset_store.get_asset(session_id, asset_id) + + if asset: + return JSONResponse( + { + 'id': asset.get('id'), + 'created_at': asset.get('created_at').isoformat(), + 'expires_at': asset.get('expires_at').isoformat(), + 'type': asset.get('type'), + 'media': asset.get('media'), + 'name': asset.get('name'), + 'format': asset.get('format'), + 'size': asset.get('size'), + 'metadata': asset.get('metadata') + }, status_code = HTTP_200_OK) + + return Response(status_code = HTTP_404_NOT_FOUND) + + +async def delete_assets(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + body = await request.json() + asset_ids = body.get('asset_ids') + + if session_id and asset_ids: + asset_set = asset_store.get_assets(session_id) + + if asset_set: + for asset_id in asset_ids: + if asset_id in asset_set: + remove_file(asset_set.get(asset_id).get('path')) + asset_store.delete_assets(session_id, asset_ids) + return Response(status_code = HTTP_200_OK) + + return Response(status_code = HTTP_404_NOT_FOUND) diff --git a/facefusion/types.py b/facefusion/types.py index d3cc82ce..cbedcdcc 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -232,7 +232,8 @@ VideoAsset = TypedDict('VideoAsset', }) AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata -AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]] +AssetSet : TypeAlias = Dict[AssetId, AudioAsset | ImageAsset | VideoAsset] +AssetStore : TypeAlias = Dict[SessionId, AssetSet] BenchmarkMode = Literal['warm', 'cold'] BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] diff --git a/tests/test_api_assets.py b/tests/test_api_assets.py index 5b5fe259..0d7d6442 100644 --- a/tests/test_api_assets.py +++ b/tests/test_api_assets.py @@ -4,7 +4,7 @@ from typing import Iterator import pytest from starlette.testclient import TestClient -from facefusion import metadata, process_manager, session_manager, state_manager +from facefusion import metadata, session_manager, state_manager from facefusion.apis import asset_store from facefusion.apis.core import create_api from facefusion.download import conditional_download @@ -101,3 +101,172 @@ def test_upload_asset(test_client : TestClient) -> None: }) assert upload_response.status_code == 400 + + +def test_get_assets(test_client : TestClient) -> None: + get_response = test_client.get('/assets') + + assert get_response.status_code == 401 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + access_token = create_session_body.get('access_token') + + get_response = test_client.get('/assets', headers = + { + 'Authorization': 'Bearer ' + access_token + }) + get_body = get_response.json() + + assert get_body.get('assets') == [] + + assert get_response.status_code == 200 + + source_path = get_test_example_file('source.jpg') + target_path = get_test_example_file('target-240p.mp4') + + with open(source_path, 'rb') as source_file, open(target_path, 'rb') as target_file: + source_content = source_file.read() + target_content = target_file.read() + test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + access_token + }, files = + [ + ('file', ('source.jpg', source_content, 'image/jpeg')), + ('file', ('target.mp4', target_content, 'video/mp4')) + ]) + + get_response = test_client.get('/assets', headers = + { + 'Authorization': 'Bearer ' + access_token + }) + get_body = get_response.json() + assets = get_body.get('assets') + + assert len(assets) == 2 + assert assets[0].get('media') == 'image' + assert assets[1].get('media') == 'video' + + assert get_response.status_code == 200 + + +def test_get_asset(test_client : TestClient) -> None: + get_response = test_client.get('invalid') + + assert get_response.status_code == 404 + + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + access_token = create_session_body.get('access_token') + + source_path = get_test_example_file('source.jpg') + + with open(source_path, 'rb') as source_file: + source_content = source_file.read() + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + access_token + }, files = + [ + ('file', ('source.jpg', source_content, 'image/jpeg')) + ]) + upload_body = upload_response.json() + asset_id = upload_body.get('asset_ids')[0] + + second_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + second_session_body = second_session_response.json() + second_access_token = second_session_body.get('access_token') + + get_response = test_client.get('/assets/' + asset_id, headers = + { + 'Authorization': 'Bearer ' + second_access_token + }) + + assert get_response.status_code == 404 + + get_response = test_client.get('/assets/' + asset_id, headers = + { + 'Authorization': 'Bearer ' + access_token + }) + get_body = get_response.json() + + assert get_body.get('id') == asset_id + assert get_body.get('type') == 'source' + assert get_body.get('media') == 'image' + assert get_body.get('format') == 'jpeg' + assert get_body.get('metadata').get('resolution') == [ 1024, 1024 ] + + assert get_response.status_code == 200 + + +def test_delete_assets(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + access_token = create_session_body.get('access_token') + session_id = session_manager.find_session_id(access_token) + + source_path = get_test_example_file('source.jpg') + + with open(source_path, 'rb') as source_file: + source_content = source_file.read() + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + access_token + }, files = + [ + ('file', ('source.jpg', source_content, 'image/jpeg')) + ]) + upload_body = upload_response.json() + asset_id = upload_body.get('asset_ids')[0] + + assert asset_store.get_asset(session_id, asset_id) + + second_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + second_session_body = second_session_response.json() + second_access_token = second_session_body.get('access_token') + + delete_response = test_client.request('DELETE', '/assets', headers = + { + 'Authorization': 'Bearer ' + second_access_token + }, json = + { + 'asset_ids': [ asset_id ] + }) + + assert delete_response.status_code == 404 + + delete_response = test_client.request('DELETE', '/assets', headers = + { + 'Authorization': 'Bearer ' + access_token + }, json = + { + 'asset_ids': [ asset_id ] + }) + + assert delete_response.status_code == 200 + + delete_response = test_client.request('DELETE', '/assets', headers = + { + 'Authorization': 'Bearer ' + access_token + }, json = + { + 'asset_ids': [ asset_id ] + }) + + assert delete_response.status_code == 404