From 771ffe308117ff25faaacc96fc00510e94a09a92 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 12 Jan 2026 18:00:43 +0530 Subject: [PATCH] asset store update --- facefusion/apis/api_helper.py | 30 +++++++++++ facefusion/apis/asset_store.py | 81 ++++++++++++++++++++++++++---- facefusion/apis/endpoints/state.py | 15 ++++-- facefusion/types.py | 62 +++++++++++++++++++++++ tests/test_api_state.py | 66 ++++++++++++++++-------- 5 files changed, 216 insertions(+), 38 deletions(-) diff --git a/facefusion/apis/api_helper.py b/facefusion/apis/api_helper.py index 65047146..30a29747 100644 --- a/facefusion/apis/api_helper.py +++ b/facefusion/apis/api_helper.py @@ -3,6 +3,10 @@ from typing import Optional from starlette.datastructures import Headers from starlette.types import Scope +from facefusion.audio import detect_audio_duration +from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata +from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution + def get_sec_websocket_protocol(scope : Scope) -> Optional[str]: protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol') @@ -13,3 +17,29 @@ def get_sec_websocket_protocol(scope : Scope) -> Optional[str]: return None + +def extract_audio_metadata(file_path : str) -> AudioMetadata: + metadata : AudioMetadata =\ + { + 'duration' : detect_audio_duration(file_path) + } + return metadata + + +def extract_image_metadata(file_path : str) -> ImageMetadata: + metadata : ImageMetadata =\ + { + 'resolution' : detect_image_resolution(file_path) + } + return metadata + + +def extract_video_metadata(file_path : str) -> VideoMetadata: + metadata : VideoMetadata =\ + { + 'duration' : detect_video_duration(file_path), + 'frame_total' : count_video_frame_total(file_path), + 'fps' : detect_video_fps(file_path), + 'resolution' : detect_video_resolution(file_path) + } + return metadata diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py index a9f0c2fc..c437dc73 100644 --- a/facefusion/apis/asset_store.py +++ b/facefusion/apis/asset_store.py @@ -1,23 +1,82 @@ +import os import uuid -from typing import Dict, Optional +from datetime import datetime, timedelta +from typing import Optional, cast -ASSET_STORE : Dict[str, Dict[str, str]] = {} +from facefusion.apis.api_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata +from facefusion.filesystem import get_file_format, get_file_name, is_audio, is_image, is_video +from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioFormat, ImageFormat, MediaType, SessionId, VideoFormat + +ASSET_STORE : AssetStore = {} -def get_asset(asset_id : str) -> Optional[Dict[str, str]]: - return ASSET_STORE.get(asset_id) +def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]: + if session_id in ASSET_STORE: + return ASSET_STORE[session_id].get(asset_id) + return None -def register_asset(path : str) -> str: +def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[Asset]: asset_id = str(uuid.uuid4()) + media_type = detect_media_type(file_path) - ASSET_STORE[asset_id] =\ - { - 'id': asset_id, - 'path': path - } - return asset_id + if media_type: + file_name = get_file_name(file_path) + file_format = get_file_format(file_path) + file_size = os.path.getsize(file_path) + created_at = datetime.now() + expires_at = created_at + timedelta(hours = 24) + asset =\ + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'name': file_name, + 'size': file_size, + 'path': file_path + } + + if media_type == 'audio': + asset.update( + { + 'media': 'audio', + 'format': cast(AudioFormat, file_format), + 'metadata': extract_audio_metadata(file_path) + }) + if media_type == 'image': + asset.update( + { + 'media': 'image', + 'format': cast(ImageFormat, file_format), + 'metadata': extract_image_metadata(file_path) + }) + if media_type == 'video': + asset.update( + { + 'media': 'video', + 'format': cast(VideoFormat, file_format), + 'metadata': extract_video_metadata(file_path) + }) + + if session_id not in ASSET_STORE: + ASSET_STORE[session_id] = {} + + ASSET_STORE[session_id][asset_id] = asset #type:ignore[assignment] + return asset #type:ignore[return-value] + + return None def clear() -> None: ASSET_STORE.clear() + + +def detect_media_type(file_path : str) -> Optional[MediaType]: + if is_image(file_path): + return 'image' + if is_video(file_path): + return 'video' + if is_audio(file_path): + return 'audio' + return None diff --git a/facefusion/apis/endpoints/state.py b/facefusion/apis/endpoints/state.py index 5875f5a0..db96c5c6 100644 --- a/facefusion/apis/endpoints/state.py +++ b/facefusion/apis/endpoints/state.py @@ -2,8 +2,9 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND -from facefusion import args_store, state_manager, translator +from facefusion import args_store, session_manager, state_manager, translator from facefusion.apis import asset_store +from facefusion.apis.endpoints.session import extract_access_token async def get_state(request : Request) -> JSONResponse: @@ -35,12 +36,14 @@ async def set_state(request : Request) -> JSONResponse: async def select_source(request : Request) -> JSONResponse: body = await request.json() asset_ids = body.get('asset_ids') + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) - if isinstance(asset_ids, list): + if isinstance(asset_ids, list) and session_id: source_paths = [] for asset_id in asset_ids: - asset = asset_store.get_asset(asset_id) + asset = asset_store.get_asset(session_id, asset_id) if asset: source_paths.append(asset.get('path')) @@ -59,9 +62,11 @@ async def select_source(request : Request) -> JSONResponse: async def select_target(request : Request) -> JSONResponse: body = await request.json() asset_id = body.get('asset_id') + access_token = extract_access_token(request.scope) + session_id = session_manager.find_session_id(access_token) - if isinstance(asset_id, str): - asset = asset_store.get_asset(asset_id) + if isinstance(asset_id, str) and session_id: + asset = asset_store.get_asset(session_id, asset_id) if asset: state_manager.set_item('target_path', asset.get('path')) diff --git a/facefusion/types.py b/facefusion/types.py index c3371f83..3f67ac1b 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -169,6 +169,68 @@ EncoderSet = TypedDict('EncoderSet', }) VideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow'] +AssetId : TypeAlias = str +AssetType = Literal['source', 'target'] +MediaType = Literal['image', 'video', 'audio'] +AudioMetadata = TypedDict('AudioMetadata', +{ + 'duration' : Duration +}) +ImageMetadata = TypedDict('ImageMetadata', +{ + 'resolution' : Resolution +}) +VideoMetadata = TypedDict('VideoMetadata', +{ + 'duration' : Duration, + 'frame_total' : int, + 'fps' : Fps, + 'resolution' : Resolution +}) +AudioAsset = TypedDict('AudioAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['audio'], + 'name' : str, + 'format' : AudioFormat, + 'size' : int, + 'path' : str, + 'metadata' : AudioMetadata +}) +ImageAsset = TypedDict('ImageAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['image'], + 'name' : str, + 'format' : ImageFormat, + 'size' : int, + 'path' : str, + 'metadata' : ImageMetadata +}) +VideoAsset = TypedDict('VideoAsset', +{ + 'id' : AssetId, + 'created_at' : datetime, + 'expires_at' : datetime, + 'type' : AssetType, + 'media' : Literal['video'], + 'name' : str, + 'format' : VideoFormat, + 'size' : int, + 'path' : str, + 'metadata' : VideoMetadata +}) + +Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset +AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata +AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]] + BenchmarkMode = Literal['warm', 'cold'] BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] BenchmarkSet : TypeAlias = Dict[BenchmarkResolution, str] diff --git a/tests/test_api_state.py b/tests/test_api_state.py index b2cd0fe0..1a6f857d 100644 --- a/tests/test_api_state.py +++ b/tests/test_api_state.py @@ -1,3 +1,4 @@ +import subprocess from typing import Iterator import pytest @@ -6,6 +7,18 @@ from starlette.testclient import TestClient from facefusion import args_store, metadata, session_manager, state_manager from facefusion.apis import asset_store from facefusion.apis.core import create_api +from facefusion.download import conditional_download +from .helper import get_test_example_file, get_test_examples_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) @pytest.fixture(scope = 'module') @@ -85,10 +98,23 @@ def test_set_state(test_client : TestClient) -> None: def test_select_source_assets(test_client : TestClient) -> None: + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + + create_session_body = create_session_response.json() + access_token = create_session_body.get('access_token') + session_id = session_manager.find_session_id(access_token) + source_paths =\ + [ + get_test_example_file('source.jpg'), + get_test_example_file('source.jpg') + ] asset_ids =\ [ - asset_store.register_asset('/path/to/source1.jpg'), - asset_store.register_asset('/path/to/source2.jpg') + asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'), + asset_store.create_asset(session_id, 'source', source_paths[1]).get('id') ] select_response = test_client.put('/state?action=select&type=source', json = @@ -98,18 +124,12 @@ def test_select_source_assets(test_client : TestClient) -> None: assert select_response.status_code == 401 - create_session_response = test_client.post('/session', json = - { - 'client_version': metadata.get('version') - }) - create_session_body = create_session_response.json() - select_response = test_client.put('/state?action=select&type=source', json = { 'asset_ids': 'invalid' }, headers = { - 'Authorization': 'Bearer ' + create_session_body.get('access_token') + 'Authorization': 'Bearer ' + access_token }) assert select_response.status_code == 404 @@ -119,36 +139,38 @@ def test_select_source_assets(test_client : TestClient) -> None: 'asset_ids': asset_ids }, headers = { - 'Authorization': 'Bearer ' + create_session_body.get('access_token') + 'Authorization': 'Bearer ' + access_token }) select_body = select_response.json() - assert select_body.get('source_paths') == [ '/path/to/source1.jpg', '/path/to/source2.jpg' ] + assert select_body.get('source_paths') == source_paths assert select_response.status_code == 200 def test_select_target_assets(test_client : TestClient) -> None: - asset_id = asset_store.register_asset('/path/to/target.jpg') + create_session_response = test_client.post('/session', json = + { + 'client_version': metadata.get('version') + }) + create_session_body = create_session_response.json() + access_token = create_session_body.get('access_token') + session_id = session_manager.find_session_id(access_token) + target_path = get_test_example_file('target-240p.jpg') + asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id') - select_response = test_client.put('/state?action=select&type=target', json = + select_response = test_client.put('/state?action=select&type=target', json= { 'asset_id': asset_id }) assert select_response.status_code == 401 - create_session_response = test_client.post('/session', json = - { - 'client_version': metadata.get('version') - }) - create_session_body = create_session_response.json() - select_response = test_client.put('/state?action=select&type=target', json = { 'asset_id': 'invalid' }, headers = { - 'Authorization': 'Bearer ' + create_session_body.get('access_token') + 'Authorization': 'Bearer ' + access_token }) assert select_response.status_code == 404 @@ -158,9 +180,9 @@ def test_select_target_assets(test_client : TestClient) -> None: 'asset_id': asset_id }, headers = { - 'Authorization': 'Bearer ' + create_session_body.get('access_token') + 'Authorization': 'Bearer ' + access_token }) select_body = select_response.json() - assert select_body.get('target_path') == '/path/to/target.jpg' + assert select_body.get('target_path') == target_path assert select_response.status_code == 200