From a09c078c9089082ba5d9c30cdae736fd5a407dc2 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 16 Jan 2026 17:05:17 +0530 Subject: [PATCH] upload asset endpoint --- facefusion/apis/asset_helper.py | 35 ++---- facefusion/apis/asset_store.py | 129 +++++++++++----------- facefusion/apis/core.py | 2 + facefusion/apis/endpoints/assets.py | 75 +++++++++++++ facefusion/types.py | 3 +- tests/test_api_assets.py | 160 ++++++++++++++++++++++++++++ tests/test_api_state.py | 6 +- 7 files changed, 321 insertions(+), 89 deletions(-) create mode 100644 facefusion/apis/endpoints/assets.py create mode 100644 tests/test_api_assets.py diff --git a/facefusion/apis/asset_helper.py b/facefusion/apis/asset_helper.py index 979d6d83..80989ef0 100644 --- a/facefusion/apis/asset_helper.py +++ b/facefusion/apis/asset_helper.py @@ -1,20 +1,17 @@ -from typing import Optional - from facefusion.audio import detect_audio_duration from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate -from facefusion.filesystem import is_audio, is_image, is_video -from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata +from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution def extract_audio_metadata(file_path : str) -> AudioMetadata: metadata : AudioMetadata =\ { - 'duration' : detect_audio_duration(file_path), - 'sample_rate' : detect_audio_sample_rate(file_path), - 'frame_total' : detect_audio_frame_total(file_path), - 'channels' : detect_audio_channel_total(file_path), - 'format' : detect_audio_format(file_path) + 'duration': detect_audio_duration(file_path), + 'sample_rate': detect_audio_sample_rate(file_path), + 'frame_total': detect_audio_frame_total(file_path), + 'channels': detect_audio_channel_total(file_path), + 'format': detect_audio_format(file_path) } return metadata @@ -22,7 +19,7 @@ def extract_audio_metadata(file_path : str) -> AudioMetadata: def extract_image_metadata(file_path : str) -> ImageMetadata: metadata : ImageMetadata =\ { - 'resolution' : detect_image_resolution(file_path) + 'resolution': detect_image_resolution(file_path) } return metadata @@ -30,19 +27,9 @@ def extract_image_metadata(file_path : str) -> ImageMetadata: def extract_video_metadata(file_path : str) -> VideoMetadata: metadata : VideoMetadata =\ { - 'duration' : detect_video_duration(file_path), - 'frame_total' : count_video_frame_total(file_path), - 'fps' : detect_video_fps(file_path), - 'resolution' : detect_video_resolution(file_path) + 'duration': detect_video_duration(file_path), + 'frame_total': count_video_frame_total(file_path), + 'fps': detect_video_fps(file_path), + 'resolution': detect_video_resolution(file_path) } return metadata - - -def detect_media_type(file_path : str) -> Optional[MediaType]: - if is_audio(file_path): - return 'audio' - if is_image(file_path): - return 'image' - if is_video(file_path): - return 'video' - return None diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py index 66520e13..2da64b0d 100644 --- a/facefusion/apis/asset_store.py +++ b/facefusion/apis/asset_store.py @@ -3,82 +3,89 @@ import uuid from datetime import datetime, timedelta from typing import Optional, cast -from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata -from facefusion.filesystem import get_file_format, get_file_name -from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat +from facefusion.apis.asset_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata +from facefusion.filesystem import get_file_format, get_file_name, is_file, remove_file +from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, MediaType, SessionId, VideoAsset, VideoFormat ASSET_STORE : AssetStore = {} -def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]: +def create_asset(session_id : SessionId, asset_type : AssetType, media_type : MediaType, file_path : str) -> Asset: asset_id = str(uuid.uuid4()) - media_type = detect_media_type(file_path) + file_name = get_file_name(file_path) + file_format = get_file_format(file_path) + file_size = os.path.getsize(file_path) + created_at = datetime.now() + expires_at = created_at + timedelta(hours = 2) - if media_type: - file_name = get_file_name(file_path) - file_format = get_file_format(file_path) - file_size = os.path.getsize(file_path) - created_at = datetime.now() - expires_at = created_at + timedelta(hours = 2) + if session_id not in ASSET_STORE: + ASSET_STORE[session_id] = {} - if session_id not in ASSET_STORE: - ASSET_STORE[session_id] = {} + if media_type == 'audio': + ASSET_STORE[session_id][asset_id] = cast(AudioAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(AudioFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_audio_metadata(file_path) + }) - if media_type == 'audio': - ASSET_STORE[session_id][asset_id] = cast(AudioAsset, - { - 'id': asset_id, - 'created_at': created_at, - 'expires_at': expires_at, - 'type': asset_type, - 'media': media_type, - 'name': file_name, - 'format': cast(AudioFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_audio_metadata(file_path) - }) + if media_type == 'image': + ASSET_STORE[session_id][asset_id] = cast(ImageAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(ImageFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_image_metadata(file_path) + }) - if media_type == 'image': - ASSET_STORE[session_id][asset_id] = cast(ImageAsset, - { - 'id': asset_id, - 'created_at': created_at, - 'expires_at': expires_at, - 'type': asset_type, - 'media': media_type, - 'name': file_name, - 'format': cast(ImageFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_image_metadata(file_path) - }) + if media_type == 'video': + ASSET_STORE[session_id][asset_id] = cast(VideoAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(VideoFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_video_metadata(file_path) + }) - if media_type == 'video': - ASSET_STORE[session_id][asset_id] = cast(VideoAsset, - { - 'id': asset_id, - 'created_at': created_at, - 'expires_at': expires_at, - 'type': asset_type, - 'media': media_type, - 'name': file_name, - 'format': cast(VideoFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_video_metadata(file_path) - }) - - return ASSET_STORE[session_id][asset_id] - - return None + return ASSET_STORE[session_id][asset_id] -def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]: +def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]: if session_id in ASSET_STORE: return ASSET_STORE[session_id].get(asset_id) return None +def clear_session(session_id : SessionId) -> None: + if session_id in ASSET_STORE: + for asset in ASSET_STORE[session_id].values(): + file_path = asset.get('path') + + if file_path and is_file(file_path): + remove_file(file_path) + + del ASSET_STORE[session_id] + + def clear() -> None: - ASSET_STORE.clear() + for session_id in list(ASSET_STORE.keys()): + clear_session(session_id) diff --git a/facefusion/apis/core.py b/facefusion/apis/core.py index 45299786..8692c4bb 100644 --- a/facefusion/apis/core.py +++ b/facefusion/apis/core.py @@ -3,6 +3,7 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route, WebSocketRoute +from facefusion.apis.endpoints.assets import upload_asset from facefusion.apis.endpoints.ping import websocket_ping from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session from facefusion.apis.endpoints.state import get_state, set_state @@ -18,6 +19,7 @@ def create_api() -> Starlette: Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]), Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]), Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]), + Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]), WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ]) ] diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py new file mode 100644 index 00000000..7203f5f1 --- /dev/null +++ b/facefusion/apis/endpoints/assets.py @@ -0,0 +1,75 @@ +import tempfile +from typing import List, Tuple + +from starlette.datastructures import UploadFile +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST + +from facefusion import session_manager +from facefusion.apis import asset_store +from facefusion.apis.endpoints.session import extract_access_token +from facefusion.filesystem import get_file_extension, is_audio, is_file, is_image, is_video, remove_file +from facefusion.types import MediaType + + +async def upload_asset(request : Request) -> JSONResponse: + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) + asset_type = request.query_params.get('type') + + if session_id and asset_type in [ 'source', 'target' ]: + form = await request.form() + files = [ file for file in form.getlist('file') if isinstance(file, UploadFile) ] + + if asset_type == 'target': + files = files[:1] + + prepared_files = await prepare_files(files) + + if prepared_files: + asset_ids : List[str] = [] + + for file_path, media_type in prepared_files: + asset = asset_store.create_asset(session_id, asset_type, media_type, file_path) #type:ignore[arg-type] + + if asset: + asset_id = asset.get('id') + + if asset_id: + asset_ids.append(asset_id) + + if asset_ids: + if asset_type == 'target': + return JSONResponse({ 'asset_id': asset_ids[0] }, status_code = HTTP_201_CREATED) + + return JSONResponse({ 'asset_ids': asset_ids }, status_code = HTTP_201_CREATED) + + return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST) + + +async def prepare_files(files : List[UploadFile]) -> List[Tuple[str, MediaType]]: + prepared_files : List[Tuple[str, MediaType]] = [] + + for file in files: + file_extension = get_file_extension(file.filename) + + with tempfile.NamedTemporaryFile(suffix = file_extension, delete = False) as temp_file: + content = await file.read() + temp_file.write(content) + file_path = temp_file.name + + if is_audio(file_path): + prepared_files.append((file_path, 'audio')) + continue + if is_image(file_path): + prepared_files.append((file_path, 'image')) + continue + if is_video(file_path): + prepared_files.append((file_path, 'video')) + continue + + if is_file(file_path): + remove_file(file_path) + + return prepared_files diff --git a/facefusion/types.py b/facefusion/types.py index cee45f81..cc7f9ee3 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -232,7 +232,8 @@ VideoAsset = TypedDict('VideoAsset', }) AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata -AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]] +Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset +AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]] BenchmarkMode = Literal['warm', 'cold'] BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] diff --git a/tests/test_api_assets.py b/tests/test_api_assets.py new file mode 100644 index 00000000..e0eb2c0f --- /dev/null +++ b/tests/test_api_assets.py @@ -0,0 +1,160 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from facefusion import metadata, session_manager +from facefusion.apis import asset_store +from facefusion.apis.core import create_api + + +@pytest.fixture(scope = 'module') +def test_client() -> Iterator[TestClient]: + with TestClient(create_api()) as test_client: + yield test_client + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + session_manager.SESSIONS.clear() + asset_store.clear() + + +def test_upload_asset_without_auth(test_client : TestClient) -> None: + upload_response = test_client.post('/assets?type=source') + + assert upload_response.status_code == 401 + + +def test_upload_asset_invalid_type(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + upload_response = test_client.post('/assets?type=invalid', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert upload_response.status_code == 400 + + +def test_upload_asset_no_file(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }) + + assert upload_response.status_code == 400 + + +def test_upload_source_asset(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + with open('.assets/examples/source.jpg', 'rb') as source_file: + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }, files = + { + 'file': ('source.jpg', source_file, 'image/jpeg') + }) + + assert upload_response.status_code == 201 + assert upload_response.json().get('asset_ids') + assert len(upload_response.json().get('asset_ids')) == 1 + + +def test_upload_multiple_source_assets(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + with open('.assets/examples/source.jpg', 'rb') as source_file_1: + with open('.assets/examples/source.jpg', 'rb') as source_file_2: + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }, files = + [ + ('file', ('source1.jpg', source_file_1, 'image/jpeg')), + ('file', ('source2.jpg', source_file_2, 'image/jpeg')) + ]) + + assert upload_response.status_code == 201 + assert upload_response.json().get('asset_ids') + assert len(upload_response.json().get('asset_ids')) == 2 + + +def test_upload_target_asset(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + with open('.assets/examples/target-240p.mp4', 'rb') as target_file: + upload_response = test_client.post('/assets?type=target', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }, files = + { + 'file': ('target.mp4', target_file, 'video/mp4') + }) + + assert upload_response.status_code == 201 + assert upload_response.json().get('asset_id') + + +def test_upload_target_multiple_files_uses_first(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + with open('.assets/examples/target-240p.mp4', 'rb') as target_file_1: + with open('.assets/examples/target-240p.mp4', 'rb') as target_file_2: + upload_response = test_client.post('/assets?type=target', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }, files = + [ + ('file', ('target1.mp4', target_file_1, 'video/mp4')), + ('file', ('target2.mp4', target_file_2, 'video/mp4')) + ]) + + assert upload_response.status_code == 201 + assert upload_response.json().get('asset_id') + + +def test_upload_unsupported_format(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + + upload_response = test_client.post('/assets?type=source', headers = + { + 'Authorization': 'Bearer ' + create_session_body.get('access_token') + }, files = + { + 'file': ('test.txt', b'invalid content', 'text/plain') + }) + + assert upload_response.status_code == 400 diff --git a/tests/test_api_state.py b/tests/test_api_state.py index 1a6f857d..f8c34113 100644 --- a/tests/test_api_state.py +++ b/tests/test_api_state.py @@ -113,8 +113,8 @@ def test_select_source_assets(test_client : TestClient) -> None: ] asset_ids =\ [ - asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'), - asset_store.create_asset(session_id, 'source', source_paths[1]).get('id') + asset_store.create_asset(session_id, 'source', 'image', source_paths[0]).get('id'), + asset_store.create_asset(session_id, 'source', 'image', source_paths[1]).get('id') ] select_response = test_client.put('/state?action=select&type=source', json = @@ -156,7 +156,7 @@ def test_select_target_assets(test_client : TestClient) -> None: access_token = create_session_body.get('access_token') session_id = session_manager.find_session_id(access_token) target_path = get_test_example_file('target-240p.jpg') - asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id') + asset_id = asset_store.create_asset(session_id, 'target', 'image', target_path).get('id') select_response = test_client.put('/state?action=select&type=target', json= {