From b7b60c186fefc9662ae5827f471d52e207339f94 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 22 Jan 2026 16:05:32 +0530 Subject: [PATCH] refactor api --- facefusion/apis/asset_helper.py | 50 ++++ facefusion/apis/asset_store.py | 92 +++++++ facefusion/apis/assets.py | 186 -------------- facefusion/apis/core.py | 16 +- facefusion/apis/endpoints/__init__.py | 0 facefusion/apis/endpoints/assets.py | 147 +++++++++++ facefusion/apis/{ => endpoints}/ping.py | 0 facefusion/apis/{ => endpoints}/session.py | 17 +- facefusion/apis/endpoints/state.py | 80 ++++++ facefusion/apis/locales.py | 14 ++ facefusion/apis/remote.py | 41 ++-- facefusion/apis/state.py | 76 ------ facefusion/apis/timeline.py | 9 +- facefusion/types.py | 64 +++++ tests/test_asset_store.py | 268 +++++++++------------ 15 files changed, 596 insertions(+), 464 deletions(-) create mode 100644 facefusion/apis/asset_helper.py create mode 100644 facefusion/apis/asset_store.py delete mode 100644 facefusion/apis/assets.py create mode 100644 facefusion/apis/endpoints/__init__.py create mode 100644 facefusion/apis/endpoints/assets.py rename facefusion/apis/{ => endpoints}/ping.py (100%) rename facefusion/apis/{ => endpoints}/session.py (87%) create mode 100644 facefusion/apis/endpoints/state.py create mode 100644 facefusion/apis/locales.py delete mode 100644 facefusion/apis/state.py diff --git a/facefusion/apis/asset_helper.py b/facefusion/apis/asset_helper.py new file mode 100644 index 00000000..808092cc --- /dev/null +++ b/facefusion/apis/asset_helper.py @@ -0,0 +1,50 @@ +from typing import Optional + +from facefusion.audio import detect_audio_duration +from facefusion.filesystem import is_audio, is_image, is_video +from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata +from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution + + +def extract_audio_metadata(file_path : str) -> AudioMetadata: + metadata : AudioMetadata =\ + { + 'duration': detect_audio_duration(file_path), + 'sample_rate': 0, + 'frame_total': 0, + 'channels': 0, + 'format': '' + } + return metadata + + +def extract_image_metadata(file_path : str) -> ImageMetadata: + resolution = detect_image_resolution(file_path) + metadata : ImageMetadata =\ + { + 'resolution': resolution if resolution else (0, 0) + } + return metadata + + +def extract_video_metadata(file_path : str) -> VideoMetadata: + resolution = detect_video_resolution(file_path) + fps = detect_video_fps(file_path) + metadata : VideoMetadata =\ + { + 'duration': detect_video_duration(file_path), + 'frame_total': count_video_frame_total(file_path), + 'fps': fps if fps else 0.0, + 'resolution': resolution if resolution else (0, 0) + } + return metadata + + +def detect_media_type(file_path : str) -> Optional[MediaType]: + if is_audio(file_path): + return 'audio' + if is_image(file_path): + return 'image' + if is_video(file_path): + return 'video' + return None diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py new file mode 100644 index 00000000..be20dbfd --- /dev/null +++ b/facefusion/apis/asset_store.py @@ -0,0 +1,92 @@ +import os +import uuid +from datetime import datetime, timedelta +from typing import List, Optional, cast + +from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata +from facefusion.filesystem import get_file_format, get_file_name +from facefusion.types import AssetId, AssetSet, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat + +ASSET_STORE : AssetStore = {} + + +def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]: + asset_id = str(uuid.uuid4()) + asset_name = get_file_name(asset_path) + asset_format = get_file_format(asset_path) + asset_size = os.path.getsize(asset_path) + media_type = detect_media_type(asset_path) + created_at = datetime.now() + expires_at = created_at + timedelta(hours = 2) + + if session_id not in ASSET_STORE: + ASSET_STORE[session_id] = {} + + if media_type == 'audio': + ASSET_STORE[session_id][asset_id] = cast(AudioAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': asset_name, + 'format': cast(AudioFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_audio_metadata(asset_path) + }) + + if media_type == 'image': + ASSET_STORE[session_id][asset_id] = cast(ImageAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': asset_name, + 'format': cast(ImageFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_image_metadata(asset_path) + }) + + if media_type == 'video': + ASSET_STORE[session_id][asset_id] = cast(VideoAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': asset_name, + 'format': cast(VideoFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_video_metadata(asset_path) + }) + + return ASSET_STORE[session_id].get(asset_id) + + +def get_assets(session_id : SessionId) -> Optional[AssetSet]: + return ASSET_STORE.get(session_id) + + +def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]: + if session_id in ASSET_STORE: + return ASSET_STORE.get(session_id).get(asset_id) + return None + + +def delete_assets(session_id : SessionId, asset_ids : List[AssetId]) -> None: + if session_id in ASSET_STORE: + for asset_id in asset_ids: + if asset_id in ASSET_STORE.get(session_id): + del ASSET_STORE[session_id][asset_id] + return None + + +def clear() -> None: + ASSET_STORE.clear() diff --git a/facefusion/apis/assets.py b/facefusion/apis/assets.py deleted file mode 100644 index 926656b5..00000000 --- a/facefusion/apis/assets.py +++ /dev/null @@ -1,186 +0,0 @@ -import os -import tempfile - -from starlette.requests import Request -from starlette.responses import FileResponse, JSONResponse -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND - -from facefusion import asset_store, filesystem, logger -from facefusion.vision import count_video_frame_total, detect_video_fps, detect_video_resolution - - -async def upload_asset(request : Request) -> JSONResponse: - asset_type = request.query_params.get('type') - - if not asset_type: - return JSONResponse({'message': 'Missing required query parameter: type'}, status_code = HTTP_400_BAD_REQUEST) - - if asset_type not in ['source', 'target']: - return JSONResponse({'message': 'Invalid type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST) - - form = await request.form() - - if asset_type == 'source': - files = form.getlist('file') - - if not files: - return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST) - - asset_ids = [] - - for file in files: - filename = file.filename if hasattr(file, 'filename') else 'source.jpg' - file_extension = os.path.splitext(filename)[1] if filename else '.jpg' - - with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file: - content = await file.read() - temp_file.write(content) - file_path = temp_file.name - - if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)): - if os.path.exists(file_path): - os.remove(file_path) - return JSONResponse( - { - 'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.' - }, - status_code = HTTP_400_BAD_REQUEST - ) - - asset_id = asset_store.register('source', file_path, filename) - asset_ids.append(asset_id) - - logger.debug(f'Uploaded {len(asset_ids)} source(s)', __name__) - - return JSONResponse( - { - 'message': f'{len(asset_ids)} source(s) uploaded successfully', - 'asset_ids': asset_ids - }, - status_code = HTTP_201_CREATED - ) - - if asset_type == 'target': - file = form.get('file') - - if not file: - return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST) - - if isinstance(file, str): - return JSONResponse({'message': 'Expected file upload, got string. Use /stream/target for URLs'}, status_code = HTTP_400_BAD_REQUEST) - - if not hasattr(file, 'filename'): - return JSONResponse({'message': 'Invalid file object'}, status_code = HTTP_400_BAD_REQUEST) - - filename = file.filename - file_extension = os.path.splitext(filename)[1] if filename else '.jpg' - - with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file: - content = await file.read() - temp_file.write(content) - file_path = temp_file.name - - if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)): - if os.path.exists(file_path): - os.remove(file_path) - return JSONResponse( - { - 'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.' - }, - status_code = HTTP_400_BAD_REQUEST - ) - - metadata = None - - if filesystem.is_video(file_path): - frame_total = count_video_frame_total(file_path) - fps = detect_video_fps(file_path) - resolution = detect_video_resolution(file_path) - metadata =\ - { - 'frame_total': frame_total, - 'fps': fps, - 'resolution': resolution - } - logger.debug(f'Video metadata - frames: {frame_total}, fps: {fps}, resolution: {resolution}', __name__) - - asset_id = asset_store.register('target', file_path, filename, metadata) - - logger.debug(f'Target uploaded with asset_id: {asset_id}', __name__) - - return JSONResponse( - { - 'message': 'Target uploaded successfully', - 'asset_id': asset_id - }, - status_code = HTTP_201_CREATED - ) - - -async def list_all_assets(request : Request) -> JSONResponse: - asset_type = request.query_params.get('type') - media_type = request.query_params.get('media_type') - format = request.query_params.get('format') - - assets = asset_store.list_assets(asset_type) - - if media_type: - assets = [a for a in assets if a.get('media_type') == media_type] - - if format: - assets = [a for a in assets if a.get('format') == format] - - safe_assets = [] - for asset in assets: - safe_asset = {k: v for k, v in asset.items() if k != 'path'} - safe_assets.append(safe_asset) - - return JSONResponse({'assets': safe_assets, 'count': len(safe_assets)}, status_code = HTTP_200_OK) - - -async def get_asset_by_id(request : Request) -> JSONResponse | FileResponse: - from facefusion.session_context import get_session_id - - asset_id = request.path_params.get('asset_id') - action = request.query_params.get('action') - asset = asset_store.get_asset(asset_id) - - if not asset: - return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND) - - if asset.get('session_id') != get_session_id(): - return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND) - - if action == 'download': - file_path = asset.get('path') - - if not file_path or not os.path.exists(file_path): - return JSONResponse({'message': 'Asset file not found'}, status_code = HTTP_404_NOT_FOUND) - - filename = asset.get('filename', 'download') - - return FileResponse(file_path, filename = filename) - - safe_asset = {k: v for k, v in asset.items() if k != 'path'} - - return JSONResponse(safe_asset, status_code = HTTP_200_OK) - - -async def delete_asset_by_id(request : Request) -> JSONResponse: - from facefusion.session_context import get_session_id - - asset_id = request.path_params.get('asset_id') - asset = asset_store.get_asset(asset_id) - - if not asset: - return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND) - - if asset.get('session_id') != get_session_id(): - return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND) - - success = asset_store.delete_asset(asset_id) - - if not success: - return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND) - - return JSONResponse({'message': 'Asset deleted successfully'}, status_code = HTTP_200_OK) diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 37407268..ac092826 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -3,14 +3,14 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute -from facefusion.apis.assets import delete_asset_by_id, get_asset_by_id, list_all_assets, upload_asset from facefusion.apis.choices import get_choices +from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset +from facefusion.apis.endpoints.ping import websocket_ping +from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session +from facefusion.apis.endpoints.state import get_state, set_state from facefusion.apis.metrics import websocket_metrics -from facefusion.apis.ping import websocket_ping from facefusion.apis.process import webrtc_offer, webrtc_stream_offer, websocket_process from facefusion.apis.remote import remote -from facefusion.apis.session import create_session, create_session_guard, destroy_session, get_session, refresh_session -from facefusion.apis.state import get_state, set_state from facefusion.apis.timeline import get_timeline from facefusion.apis.version import create_version_guard @@ -26,11 +26,11 @@ def create_api() -> Starlette: Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]), Route('/state', get_state, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/state', set_state, methods = [ 'PUT' ], middleware = [ version_guard, session_guard ]), + Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]), - Route('/assets', list_all_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), - Route('/assets/{asset_id}', get_asset_by_id, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), - Route('/assets/{asset_id}', delete_asset_by_id, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]), - Route('/choices', get_choices, methods=['GET'], middleware=[ version_guard, session_guard ]), + Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), + Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]), + Route('/choices', get_choices, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/remote', remote, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]), Route('/timeline/{count:int}', get_timeline, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]), Route('/webrtc/offer', webrtc_offer, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]), diff --git a/facefusion/apis/endpoints/__init__.py b/facefusion/apis/endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py new file mode 100644 index 00000000..73af75e6 --- /dev/null +++ b/facefusion/apis/endpoints/assets.py @@ -0,0 +1,147 @@ +import tempfile +from typing import List + +from starlette.datastructures import UploadFile +from starlette.requests import Request +from starlette.responses import FileResponse, JSONResponse, Response +from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND + +from facefusion import session_manager +from facefusion.apis import asset_store +from facefusion.apis.asset_helper import detect_media_type +from facefusion.apis.endpoints.session import extract_access_token +from facefusion.filesystem import get_file_extension, remove_file + + +async def upload_asset(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + asset_type = request.query_params.get('type') + + if session_id and asset_type in [ 'source', 'target' ]: + form = await request.form() + upload_files = form.getlist('file') + asset_paths = await save_asset_files(upload_files) # type: ignore[arg-type] + + if asset_paths: + asset_ids : List[str] = [] + + for asset_path in asset_paths: + asset = asset_store.create_asset(session_id, asset_type, asset_path) # type: ignore[arg-type] + asset_id = asset.get('id') + + if asset_id: + asset_ids.append(asset_id) + + if asset_ids: + return JSONResponse( + { + 'asset_ids': asset_ids + }, status_code = HTTP_201_CREATED) + + return Response(status_code = HTTP_400_BAD_REQUEST) + + +async def save_asset_files(upload_files : List[UploadFile]) -> List[str]: + asset_paths : List[str] = [] + + for upload_file in upload_files: + upload_file_extension = get_file_extension(upload_file.filename) + + with tempfile.NamedTemporaryFile(suffix = upload_file_extension, delete = False) as temp_file: + + while upload_chunk := await upload_file.read(1024): + temp_file.write(upload_chunk) + + temp_file.flush() + + media_type = detect_media_type(temp_file.name) + + if media_type: + asset_paths.append(temp_file.name) + else: + remove_file(temp_file.name) + + return asset_paths + + +async def get_assets(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + + if session_id: + asset_set = asset_store.get_assets(session_id) + assets = [] + + if asset_set: + for asset in asset_set.values(): + assets.append( + { + 'id': asset.get('id'), + 'created_at': asset.get('created_at').isoformat(), + 'expires_at': asset.get('expires_at').isoformat(), + 'type': asset.get('type'), + 'media': asset.get('media'), + 'name': asset.get('name'), + 'format': asset.get('format'), + 'size': asset.get('size'), + 'metadata': asset.get('metadata') + }) + + return JSONResponse( + { + 'assets': assets + }, status_code = HTTP_200_OK) + + return Response(status_code = HTTP_400_BAD_REQUEST) + + +async def get_asset(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + asset_id = request.path_params.get('asset_id') + action = request.query_params.get('action') + + if session_id and asset_id: + asset = asset_store.get_asset(session_id, asset_id) + + if asset: + if action == 'download': + asset_path = asset.get('path') + asset_name = asset.get('name') + + return FileResponse(asset_path, filename = asset_name) + + return JSONResponse( + { + 'id': asset.get('id'), + 'created_at': asset.get('created_at').isoformat(), + 'expires_at': asset.get('expires_at').isoformat(), + 'type': asset.get('type'), + 'media': asset.get('media'), + 'name': asset.get('name'), + 'format': asset.get('format'), + 'size': asset.get('size'), + 'metadata': asset.get('metadata') + }, status_code = HTTP_200_OK) + + return Response(status_code = HTTP_404_NOT_FOUND) + + +async def delete_assets(request : Request) -> Response: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + body = await request.json() + asset_ids = body.get('asset_ids') + + if session_id and asset_ids: + asset_set = asset_store.get_assets(session_id) + + if asset_set: + for asset_id in asset_ids: + if asset_id in asset_set: + remove_file(asset_set.get(asset_id).get('path')) + asset_store.delete_assets(session_id, asset_ids) + return Response(status_code = HTTP_200_OK) + + return Response(status_code = HTTP_404_NOT_FOUND) diff --git a/facefusion/apis/ping.py b/facefusion/apis/endpoints/ping.py similarity index 100% rename from facefusion/apis/ping.py rename to facefusion/apis/endpoints/ping.py diff --git a/facefusion/apis/session.py b/facefusion/apis/endpoints/session.py similarity index 87% rename from facefusion/apis/session.py rename to facefusion/apis/endpoints/session.py index 2668e408..f26466ed 100644 --- a/facefusion/apis/session.py +++ b/facefusion/apis/endpoints/session.py @@ -30,7 +30,7 @@ async def create_session(request : Request) -> JSONResponse: return JSONResponse( { - 'message': translator.get('something_went_wrong', __package__) + 'message': translator.get('something_went_wrong', 'facefusion.apis') }, status_code = HTTP_401_UNAUTHORIZED) @@ -53,7 +53,7 @@ async def get_session(request : Request) -> JSONResponse: return JSONResponse( { - 'message': translator.get('something_went_wrong', __package__) + 'message': translator.get('something_went_wrong', 'facefusion.apis') }, status_code = HTTP_401_UNAUTHORIZED) @@ -73,7 +73,7 @@ async def refresh_session(request : Request) -> JSONResponse: return JSONResponse( { - 'message': translator.get('something_went_wrong', __package__) + 'message': translator.get('something_went_wrong', 'facefusion.apis') }, status_code = HTTP_401_UNAUTHORIZED) @@ -88,12 +88,12 @@ async def destroy_session(request : Request) -> JSONResponse: return JSONResponse( { - 'message': translator.get('ok', __package__) + 'message': translator.get('ok', 'facefusion.apis') }, status_code = HTTP_200_OK) return JSONResponse( { - 'message': translator.get('something_went_wrong', __package__) + 'message': translator.get('something_went_wrong', 'facefusion.apis') }, status_code = HTTP_401_UNAUTHORIZED) @@ -106,20 +106,19 @@ def create_session_guard(app : ASGIApp) -> ASGIApp: if session_id: if session_manager.validate_session(session_id): - from facefusion.session_context import set_session_id - set_session_id(session_id) + session_context.set_session_id(session_id) return await app(scope, receive, send) response = JSONResponse( { - 'message': translator.get('invalid_access_token', __package__) + 'message': translator.get('invalid_access_token', 'facefusion.apis') }, status_code = HTTP_426_UPGRADE_REQUIRED) return await response(scope, receive, send) response = JSONResponse( { - 'message': translator.get('invalid_access_token', __package__) + 'message': translator.get('invalid_access_token', 'facefusion.apis') }, status_code = HTTP_401_UNAUTHORIZED) return await response(scope, receive, send) diff --git a/facefusion/apis/endpoints/state.py b/facefusion/apis/endpoints/state.py new file mode 100644 index 00000000..334cc71e --- /dev/null +++ b/facefusion/apis/endpoints/state.py @@ -0,0 +1,80 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND + +from facefusion import args_store, session_manager, state_manager, translator +from facefusion.apis import asset_store +from facefusion.apis.endpoints.session import extract_access_token + + +async def get_state(request : Request) -> JSONResponse: + api_args = args_store.filter_api_args(state_manager.get_state()) + return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK) + + +async def set_state(request : Request) -> JSONResponse: + action = request.query_params.get('action') + asset_type = request.query_params.get('asset_type') + + if action == 'select' and asset_type == 'source': + return await select_source(request) + + if action == 'select' and asset_type == 'target': + return await select_target(request) + + body = await request.json() + api_args = args_store.get_api_args() + + for key, value in body.items(): + if key in api_args: + state_manager.set_item(key, value) + + __api_args__ = args_store.filter_api_args(state_manager.get_state()) + return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) + + +async def select_source(request : Request) -> JSONResponse: + body = await request.json() + asset_ids = body.get('asset_ids') + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + + if isinstance(asset_ids, list) and session_id: + source_paths = [] + + for asset_id in asset_ids: + asset = asset_store.get_asset(session_id, asset_id) + + if asset: + source_paths.append(asset.get('path')) + + state_manager.set_item('source_paths', source_paths) + + __api_args__ = args_store.filter_api_args(state_manager.get_state()) + return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('source_asset_not_found', 'facefusion.apis') + }, status_code = HTTP_404_NOT_FOUND) + + +async def select_target(request : Request) -> JSONResponse: + body = await request.json() + asset_id = body.get('asset_id') + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + + if isinstance(asset_id, str) and session_id: + asset = asset_store.get_asset(session_id, asset_id) + + if asset: + state_manager.set_item('target_path', asset.get('path')) + + __api_args__ = args_store.filter_api_args(state_manager.get_state()) + return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) + + return JSONResponse( + { + 'message': translator.get('target_asset_not_found', 'facefusion.apis') + }, status_code = HTTP_404_NOT_FOUND) diff --git a/facefusion/apis/locales.py b/facefusion/apis/locales.py new file mode 100644 index 00000000..1f1c6ceb --- /dev/null +++ b/facefusion/apis/locales.py @@ -0,0 +1,14 @@ +from facefusion.types import Locales + +LOCALES : Locales =\ +{ + 'en': + { + 'ok': 'ok', + 'something_went_wrong': 'something went wrong', + 'invalid_access_token': 'invalid access token', + 'invalid_refresh_token': 'invalid refresh token', + 'source_asset_not_found': 'source asset not found', + 'target_asset_not_found': 'target asset not found' + } +} diff --git a/facefusion/apis/remote.py b/facefusion/apis/remote.py index e112b9c2..2fb434db 100644 --- a/facefusion/apis/remote.py +++ b/facefusion/apis/remote.py @@ -9,8 +9,10 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR -from facefusion import asset_store, logger +from facefusion import logger +from facefusion.apis import asset_store from facefusion.choices import audio_formats +from facefusion.session_context import get_session_id def resolve_image_urls(url : str) -> List[str]: @@ -66,12 +68,14 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]: gdl_job = gallery_job.DownloadJob(url) gdl_job.run() + session_id = get_session_id() for root, dirs, files in os.walk(output_dir): for filename in files: file_path = os.path.join(root, filename) - asset_id = asset_store.register(asset_type, file_path, filename) - asset_ids.append(asset_id) - logger.info(f'Registered image as asset {asset_id}', __name__) + asset = asset_store.create_asset(session_id, asset_type, file_path) + if asset: + asset_ids.append(asset.get('id')) + logger.info(f'Registered image as asset {asset.get("id")}', __name__) break @@ -108,9 +112,11 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]: for chunk in response.iter_bytes(chunk_size = 8192): f.write(chunk) - asset_id = asset_store.register(asset_type, file_path, filename) - asset_ids.append(asset_id) - logger.info(f'Downloaded and registered image as asset {asset_id}', __name__) + session_id = get_session_id() + asset = asset_store.create_asset(session_id, asset_type, file_path) + if asset: + asset_ids.append(asset.get('id')) + logger.info(f'Downloaded and registered image as asset {asset.get("id")}', __name__) return asset_ids @@ -138,9 +144,11 @@ def download_audio_from_url(url : str, asset_type : str) -> List[str]: for chunk in response.iter_bytes(chunk_size = 8192): f.write(chunk) - asset_id = asset_store.register(asset_type, file_path, filename) - asset_ids.append(asset_id) - logger.info(f'Downloaded and registered audio as asset {asset_id}', __name__) + session_id = get_session_id() + asset = asset_store.create_asset(session_id, asset_type, file_path) + if asset: + asset_ids.append(asset.get('id')) + logger.info(f'Downloaded and registered audio as asset {asset.get("id")}', __name__) return asset_ids @@ -352,16 +360,9 @@ async def remote(request : Request) -> JSONResponse: total_frames = int(duration * fps) logger.info(f'Calculated total frames: {total_frames} ({duration}s * {fps} fps)', __name__) - filename = os.path.basename(downloaded_file) - metadata =\ - { - 'frame_total': total_frames, - 'fps': fps, - 'resolution': (width, height) if width and height else None, - 'duration': duration - } - - asset_id = asset_store.register(asset_type, downloaded_file, filename, metadata) + session_id = get_session_id() + asset = asset_store.create_asset(session_id, asset_type, downloaded_file) + asset_id = asset.get('id') if asset else None logger.info(f'Video downloaded and registered as asset {asset_id}', __name__) response_data =\ diff --git a/facefusion/apis/state.py b/facefusion/apis/state.py deleted file mode 100644 index c440b901..00000000 --- a/facefusion/apis/state.py +++ /dev/null @@ -1,76 +0,0 @@ -from starlette.requests import Request -from starlette.responses import JSONResponse -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND - -from facefusion import args_store, asset_store, logger, state_manager - - -async def get_state(request : Request) -> JSONResponse: - api_args = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] - return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK) - - -async def set_state(request : Request) -> JSONResponse: - body = await request.json() - action = request.query_params.get('action') - - if action == 'select': - asset_type = request.query_params.get('asset_type') - - if not asset_type: - return JSONResponse({'message': 'Missing required query parameter: asset_type'}, status_code = HTTP_400_BAD_REQUEST) - - if asset_type not in ['source', 'target']: - return JSONResponse({'message': 'Invalid asset_type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST) - - if asset_type == 'source': - asset_ids = body.get('asset_ids', []) - - if not isinstance(asset_ids, list): - return JSONResponse({'message': 'asset_ids must be an array'}, status_code = HTTP_400_BAD_REQUEST) - - if not asset_ids: - return JSONResponse({'message': 'asset_ids cannot be empty'}, status_code = HTTP_400_BAD_REQUEST) - - source_paths = [] - for asset_id in asset_ids: - asset = asset_store.get_asset(asset_id) - if not asset: - return JSONResponse({'message': f'Source asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND) - source_paths.append(asset['path']) - - state_manager.set_item('source_paths', source_paths) - - __api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] - return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) - - if asset_type == 'target': - asset_id = body.get('asset_id') - - if not asset_id: - return JSONResponse({'message': 'Missing required field: asset_id'}, status_code = HTTP_400_BAD_REQUEST) - - if not isinstance(asset_id, str): - return JSONResponse({'message': 'asset_id must be a string'}, status_code = HTTP_400_BAD_REQUEST) - - asset = asset_store.get_asset(asset_id) - if not asset: - return JSONResponse({'message': f'Target asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND) - - state_manager.set_item('target_path', asset['path']) - - __api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] - return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) - - api_args = args_store.get_api_args() - logger.info(f'[State] Normal update - body keys: {list(body.keys())}', __name__) - - for key, value in body.items(): - if key in api_args: - state_manager.set_item(key, value) - logger.debug(f'[State] Set {key} = {value}', __name__) - else: - logger.warn(f'[State] Skipped {key} (not in api_args)', __name__) - - __api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type] - return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK) diff --git a/facefusion/apis/timeline.py b/facefusion/apis/timeline.py index 81afca8a..bf7d4a37 100644 --- a/facefusion/apis/timeline.py +++ b/facefusion/apis/timeline.py @@ -9,7 +9,7 @@ from starlette.responses import JSONResponse from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from facefusion import logger -from facefusion.asset_store import get_asset +from facefusion.apis import asset_store from facefusion.filesystem import is_video from facefusion.video_manager import get_video_capture from facefusion.vision import fit_contain_frame @@ -81,14 +81,11 @@ async def get_timeline(request: Request) -> JSONResponse: if asset_id and not target_path: from facefusion.session_context import get_session_id - asset = get_asset(asset_id) + session_id = get_session_id() + asset = asset_store.get_asset(session_id, asset_id) if not asset: return JSONResponse({'message': f'Asset not found: {asset_id}'}, status_code=HTTP_400_BAD_REQUEST) - # Verify asset belongs to current session (security) - if asset.get('session_id') != get_session_id(): - return JSONResponse({'message': 'Asset not found'}, status_code=HTTP_400_BAD_REQUEST) - target_path = asset.get('path') if not target_path: return JSONResponse({'message': 'Asset has no path'}, status_code=HTTP_400_BAD_REQUEST) diff --git a/facefusion/types.py b/facefusion/types.py index 429283ab..404fad0f 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -162,6 +162,70 @@ AudioTypeSet : TypeAlias = Dict[AudioFormat, str] ImageTypeSet : TypeAlias = Dict[ImageFormat, str] VideoTypeSet : TypeAlias = Dict[VideoFormat, str] +AssetId : TypeAlias = str +AssetType = Literal['source', 'target'] +AudioMetadata = TypedDict('AudioMetadata', +{ + 'duration' : Duration, + 'sample_rate': int, + 'frame_total': int, + 'channels': int, + 'format': str +}) +ImageMetadata = TypedDict('ImageMetadata', +{ + 'resolution' : Resolution +}) +VideoMetadata = TypedDict('VideoMetadata', +{ + 'duration' : Duration, + 'frame_total' : int, + 'fps' : Fps, + 'resolution' : Resolution +}) +AudioAsset = TypedDict('AudioAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['audio'], + 'name' : str, + 'format' : AudioFormat, + 'size' : int, + 'path' : str, + 'metadata' : AudioMetadata +}) +ImageAsset = TypedDict('ImageAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['image'], + 'name' : str, + 'format' : ImageFormat, + 'size' : int, + 'path' : str, + 'metadata' : ImageMetadata +}) +VideoAsset = TypedDict('VideoAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['video'], + 'name' : str, + 'format' : VideoFormat, + 'size' : int, + 'path' : str, + 'metadata' : VideoMetadata +}) +AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata +AssetSet : TypeAlias = Dict[AssetId, AudioAsset | ImageAsset | VideoAsset] +AssetStore : TypeAlias = Dict[SessionId, AssetSet] + AudioEncoder = Literal['flac', 'aac', 'libmp3lame', 'libopus', 'libvorbis', 'pcm_s16le', 'pcm_s32le'] VideoEncoder = Literal['libx264', 'libx264rgb', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf', 'h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox', 'rawvideo'] EncoderSet = TypedDict('EncoderSet', diff --git a/tests/test_asset_store.py b/tests/test_asset_store.py index 60a9259b..3c2360f8 100644 --- a/tests/test_asset_store.py +++ b/tests/test_asset_store.py @@ -2,24 +2,24 @@ import os import tempfile from typing import Iterator +import cv2 +import numpy import pytest -from facefusion import asset_store, session_manager, state_manager -from facefusion.session_context import clear_session_id, set_session_id +from facefusion.apis import asset_store @pytest.fixture(scope = 'function', autouse = True) def before_each() -> None: - session_manager.SESSIONS.clear() - state_manager.clear_item('asset_registry') - clear_session_id() + asset_store.clear() @pytest.fixture(scope = 'function') def temp_file() -> Iterator[str]: fd, path = tempfile.mkstemp(suffix = '.jpg') - os.write(fd, b'test file content') os.close(fd) + image = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path, image) yield path if os.path.exists(path): os.remove(path) @@ -27,221 +27,171 @@ def temp_file() -> Iterator[str]: @pytest.fixture(scope = 'function') def session_id() -> str: - test_session_id = 'test-session-123' - set_session_id(test_session_id) - return test_session_id + return 'test-session-123' -def test_register_source_asset(temp_file : str, session_id : str) -> None: - asset_id = asset_store.register('source', temp_file, 'test.jpg') +def test_create_source_asset(temp_file : str, session_id : str) -> None: + asset = asset_store.create_asset(session_id, 'source', temp_file) - assert isinstance(asset_id, str) - assert len(asset_id) == 36 - - asset = asset_store.get_asset(asset_id) assert asset is not None - assert asset.get('id') == asset_id - assert asset.get('session_id') == session_id + assert isinstance(asset.get('id'), str) + assert len(asset.get('id')) == 36 assert asset.get('type') == 'source' assert asset.get('path') == temp_file - assert asset.get('filename') == 'test.jpg' assert asset.get('size') > 0 assert asset.get('created_at') -def test_register_target_asset(temp_file : str, session_id : str) -> None: - asset_id = asset_store.register('target', temp_file, 'video.mp4') +def test_create_target_asset(temp_file : str, session_id : str) -> None: + asset = asset_store.create_asset(session_id, 'target', temp_file) - asset = asset_store.get_asset(asset_id) + assert asset is not None assert asset.get('type') == 'target' - assert asset.get('filename') == 'video.mp4' -def test_register_output_asset(temp_file : str, session_id : str) -> None: - metadata = {'fps': 30, 'resolution': [1920, 1080]} - asset_id = asset_store.register('output', temp_file, 'output.mp4', metadata) +def test_get_asset(temp_file : str, session_id : str) -> None: + created_asset = asset_store.create_asset(session_id, 'source', temp_file) + asset_id = created_asset.get('id') - asset = asset_store.get_asset(asset_id) - assert asset.get('type') == 'output' - assert asset.get('metadata') == metadata - - -def test_register_invalid_type(temp_file : str, session_id : str) -> None: - with pytest.raises(ValueError) as exc: - asset_store.register('invalid_type', temp_file, 'test.jpg') - assert "Invalid asset_type" in str(exc.value) - - -def test_register_without_session() -> None: - fd, path = tempfile.mkstemp() - os.close(fd) - - try: - with pytest.raises(ValueError) as exc: - asset_store.register('source', path, 'test.jpg') - assert "No active session" in str(exc.value) - finally: - os.remove(path) - - -def test_register_without_filename(temp_file : str, session_id : str) -> None: - asset_id = asset_store.register('source', temp_file) - - asset = asset_store.get_asset(asset_id) - assert asset.get('filename') == os.path.basename(temp_file) + asset = asset_store.get_asset(session_id, asset_id) + assert asset is not None + assert asset.get('id') == asset_id + assert asset.get('type') == 'source' def test_get_asset_not_found(session_id : str) -> None: - asset = asset_store.get_asset('non-existent-id') + asset = asset_store.get_asset(session_id, 'non-existent-id') assert asset is None -def test_list_assets_empty(session_id : str) -> None: - assets = asset_store.list_assets() - assert assets == [] +def test_get_asset_wrong_session(temp_file : str, session_id : str) -> None: + created_asset = asset_store.create_asset(session_id, 'source', temp_file) + asset_id = created_asset.get('id') + + asset = asset_store.get_asset('different-session', asset_id) + assert asset is None -def test_list_assets_with_multiple(temp_file : str, session_id : str) -> None: +def test_get_assets_empty(session_id : str) -> None: + assets = asset_store.get_assets(session_id) + assert assets is None + + +def test_get_assets_with_multiple(temp_file : str, session_id : str) -> None: fd1, path1 = tempfile.mkstemp(suffix = '.jpg') - os.write(fd1, b'content 1') os.close(fd1) + image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path1, image1) - fd2, path2 = tempfile.mkstemp(suffix = '.mp4') - os.write(fd2, b'content 2') + fd2, path2 = tempfile.mkstemp(suffix = '.jpg') os.close(fd2) + image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path2, image2) try: - asset_store.register('source', path1, 'source1.jpg') - asset_store.register('source', path2, 'source2.jpg') - asset_store.register('target', temp_file, 'target.mp4') + asset_store.create_asset(session_id, 'source', path1) + asset_store.create_asset(session_id, 'source', path2) + asset_store.create_asset(session_id, 'target', temp_file) - assets = asset_store.list_assets() + assets = asset_store.get_assets(session_id) + assert assets is not None assert len(assets) == 3 finally: - os.remove(path1) - os.remove(path2) + if os.path.exists(path1): + os.remove(path1) + if os.path.exists(path2): + os.remove(path2) -def test_list_assets_filter_by_type(temp_file : str, session_id : str) -> None: - fd, path = tempfile.mkstemp(suffix = '.jpg') - os.write(fd, b'content') - os.close(fd) - - try: - asset_store.register('source', path, 'source.jpg') - asset_store.register('target', temp_file, 'target.mp4') - - source_assets = asset_store.list_assets('source') - assert len(source_assets) == 1 - assert source_assets[0].get('type') == 'source' - - target_assets = asset_store.list_assets('target') - assert len(target_assets) == 1 - assert target_assets[0].get('type') == 'target' - - output_assets = asset_store.list_assets('output') - assert len(output_assets) == 0 - finally: - os.remove(path) - - -def test_list_assets_invalid_type(session_id : str) -> None: - with pytest.raises(ValueError) as exc: - asset_store.list_assets('invalid_type') - assert "Invalid asset_type" in str(exc.value) - - -def test_list_assets_session_scoped(temp_file : str) -> None: +def test_get_assets_session_scoped(temp_file : str) -> None: session1_id = 'session-1' - set_session_id(session1_id) - asset1_id = asset_store.register('source', temp_file, 'file1.jpg') + asset1 = asset_store.create_asset(session1_id, 'source', temp_file) + asset1_id = asset1.get('id') session2_id = 'session-2' - set_session_id(session2_id) fd, path2 = tempfile.mkstemp(suffix = '.jpg') - os.write(fd, b'content 2') os.close(fd) + image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path2, image2) try: - asset2_id = asset_store.register('source', path2, 'file2.jpg') + asset2 = asset_store.create_asset(session2_id, 'source', path2) + asset2_id = asset2.get('id') - assets_session2 = asset_store.list_assets() + assets_session2 = asset_store.get_assets(session2_id) + assert assets_session2 is not None assert len(assets_session2) == 1 - assert assets_session2[0].get('id') == asset2_id + assert asset2_id in assets_session2 - set_session_id(session1_id) - assets_session1 = asset_store.list_assets() + assets_session1 = asset_store.get_assets(session1_id) + assert assets_session1 is not None assert len(assets_session1) == 1 - assert assets_session1[0].get('id') == asset1_id + assert asset1_id in assets_session1 finally: - os.remove(path2) + if os.path.exists(path2): + os.remove(path2) -def test_delete_asset(temp_file : str, session_id : str) -> None: - asset_id = asset_store.register('source', temp_file, 'test.jpg') - - assert os.path.exists(temp_file) - - success = asset_store.delete_asset(asset_id) - assert success is True - - assert not os.path.exists(temp_file) - - asset = asset_store.get_asset(asset_id) - assert asset is None - - -def test_delete_asset_not_found(session_id : str) -> None: - success = asset_store.delete_asset('non-existent-id') - assert success is False - - -def test_cleanup_session_assets(session_id : str) -> None: +def test_delete_assets(session_id : str) -> None: fd1, path1 = tempfile.mkstemp(suffix = '.jpg') - os.write(fd1, b'content 1') os.close(fd1) + image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path1, image1) - fd2, path2 = tempfile.mkstemp(suffix = '.mp4') - os.write(fd2, b'content 2') + fd2, path2 = tempfile.mkstemp(suffix = '.jpg') os.close(fd2) + image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path2, image2) - asset_id1 = asset_store.register('source', path1, 'source.jpg') - asset_id2 = asset_store.register('target', path2, 'target.mp4') + try: + asset1 = asset_store.create_asset(session_id, 'source', path1) + asset2 = asset_store.create_asset(session_id, 'source', path2) + asset1_id = asset1.get('id') + asset2_id = asset2.get('id') - assert os.path.exists(path1) - assert os.path.exists(path2) + asset_store.delete_assets(session_id, [asset1_id]) - asset_store.cleanup_session_assets(session_id) - - assert not os.path.exists(path1) - assert not os.path.exists(path2) - - assert asset_store.get_asset(asset_id1) is None - assert asset_store.get_asset(asset_id2) is None + assets = asset_store.get_assets(session_id) + assert assets is not None + assert len(assets) == 1 + assert asset2_id in assets + assert asset1_id not in assets + finally: + if os.path.exists(path1): + os.remove(path1) + if os.path.exists(path2): + os.remove(path2) -def test_cleanup_session_assets_only_affects_target_session(temp_file : str) -> None: - session1_id = 'session-1' - set_session_id(session1_id) +def test_delete_assets_not_found(session_id : str) -> None: + asset_store.delete_assets(session_id, ['non-existent-id']) - fd, path1 = tempfile.mkstemp(suffix = '.jpg') - os.write(fd, b'content 1') - os.close(fd) - asset1_id = asset_store.register('source', path1, 'file1.jpg') +def test_clear() -> None: + fd1, path1 = tempfile.mkstemp(suffix = '.jpg') + os.close(fd1) + image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path1, image1) - session2_id = 'session-2' - set_session_id(session2_id) - asset2_id = asset_store.register('source', temp_file, 'file2.jpg') + fd2, path2 = tempfile.mkstemp(suffix = '.jpg') + os.close(fd2) + image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8) + cv2.imwrite(path2, image2) - asset_store.cleanup_session_assets(session1_id) + try: + session1_id = 'session-1' + session2_id = 'session-2' - assert not os.path.exists(path1) - assert os.path.exists(temp_file) + asset_store.create_asset(session1_id, 'source', path1) + asset_store.create_asset(session2_id, 'source', path2) - set_session_id(session1_id) - assert asset_store.get_asset(asset1_id) is None + asset_store.clear() - set_session_id(session2_id) - assert asset_store.get_asset(asset2_id) is not None + assert asset_store.get_assets(session1_id) is None + assert asset_store.get_assets(session2_id) is None + finally: + if os.path.exists(path1): + os.remove(path1) + if os.path.exists(path2): + os.remove(path2)