asset store update

This commit is contained in:
harisreedhar
2026-01-12 18:00:43 +05:30
committed by henryruhs
parent f59bf04bc2
commit 771ffe3081
5 changed files with 216 additions and 38 deletions
+30
View File
@@ -3,6 +3,10 @@ from typing import Optional
from starlette.datastructures import Headers
from starlette.types import Scope
from facefusion.audio import detect_audio_duration
from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol')
@@ -13,3 +17,29 @@ def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
return None
def extract_audio_metadata(file_path : str) -> AudioMetadata:
metadata : AudioMetadata =\
{
'duration' : detect_audio_duration(file_path)
}
return metadata
def extract_image_metadata(file_path : str) -> ImageMetadata:
metadata : ImageMetadata =\
{
'resolution' : detect_image_resolution(file_path)
}
return metadata
def extract_video_metadata(file_path : str) -> VideoMetadata:
metadata : VideoMetadata =\
{
'duration' : detect_video_duration(file_path),
'frame_total' : count_video_frame_total(file_path),
'fps' : detect_video_fps(file_path),
'resolution' : detect_video_resolution(file_path)
}
return metadata
+70 -11
View File
@@ -1,23 +1,82 @@
import os
import uuid
from typing import Dict, Optional
from datetime import datetime, timedelta
from typing import Optional, cast
ASSET_STORE : Dict[str, Dict[str, str]] = {}
from facefusion.apis.api_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
from facefusion.filesystem import get_file_format, get_file_name, is_audio, is_image, is_video
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioFormat, ImageFormat, MediaType, SessionId, VideoFormat
ASSET_STORE : AssetStore = {}
def get_asset(asset_id : str) -> Optional[Dict[str, str]]:
return ASSET_STORE.get(asset_id)
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
if session_id in ASSET_STORE:
return ASSET_STORE[session_id].get(asset_id)
return None
def register_asset(path : str) -> str:
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[Asset]:
asset_id = str(uuid.uuid4())
media_type = detect_media_type(file_path)
ASSET_STORE[asset_id] =\
{
'id': asset_id,
'path': path
}
return asset_id
if media_type:
file_name = get_file_name(file_path)
file_format = get_file_format(file_path)
file_size = os.path.getsize(file_path)
created_at = datetime.now()
expires_at = created_at + timedelta(hours = 24)
asset =\
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'name': file_name,
'size': file_size,
'path': file_path
}
if media_type == 'audio':
asset.update(
{
'media': 'audio',
'format': cast(AudioFormat, file_format),
'metadata': extract_audio_metadata(file_path)
})
if media_type == 'image':
asset.update(
{
'media': 'image',
'format': cast(ImageFormat, file_format),
'metadata': extract_image_metadata(file_path)
})
if media_type == 'video':
asset.update(
{
'media': 'video',
'format': cast(VideoFormat, file_format),
'metadata': extract_video_metadata(file_path)
})
if session_id not in ASSET_STORE:
ASSET_STORE[session_id] = {}
ASSET_STORE[session_id][asset_id] = asset #type:ignore[assignment]
return asset #type:ignore[return-value]
return None
def clear() -> None:
ASSET_STORE.clear()
def detect_media_type(file_path : str) -> Optional[MediaType]:
if is_image(file_path):
return 'image'
if is_video(file_path):
return 'video'
if is_audio(file_path):
return 'audio'
return None
+10 -5
View File
@@ -2,8 +2,9 @@ from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND
from facefusion import args_store, state_manager, translator
from facefusion import args_store, session_manager, state_manager, translator
from facefusion.apis import asset_store
from facefusion.apis.endpoints.session import extract_access_token
async def get_state(request : Request) -> JSONResponse:
@@ -35,12 +36,14 @@ async def set_state(request : Request) -> JSONResponse:
async def select_source(request : Request) -> JSONResponse:
body = await request.json()
asset_ids = body.get('asset_ids')
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
if isinstance(asset_ids, list):
if isinstance(asset_ids, list) and session_id:
source_paths = []
for asset_id in asset_ids:
asset = asset_store.get_asset(asset_id)
asset = asset_store.get_asset(session_id, asset_id)
if asset:
source_paths.append(asset.get('path'))
@@ -59,9 +62,11 @@ async def select_source(request : Request) -> JSONResponse:
async def select_target(request : Request) -> JSONResponse:
body = await request.json()
asset_id = body.get('asset_id')
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
if isinstance(asset_id, str):
asset = asset_store.get_asset(asset_id)
if isinstance(asset_id, str) and session_id:
asset = asset_store.get_asset(session_id, asset_id)
if asset:
state_manager.set_item('target_path', asset.get('path'))
+62
View File
@@ -169,6 +169,68 @@ EncoderSet = TypedDict('EncoderSet',
})
VideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow']
AssetId : TypeAlias = str
AssetType = Literal['source', 'target']
MediaType = Literal['image', 'video', 'audio']
AudioMetadata = TypedDict('AudioMetadata',
{
'duration' : Duration
})
ImageMetadata = TypedDict('ImageMetadata',
{
'resolution' : Resolution
})
VideoMetadata = TypedDict('VideoMetadata',
{
'duration' : Duration,
'frame_total' : int,
'fps' : Fps,
'resolution' : Resolution
})
AudioAsset = TypedDict('AudioAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['audio'],
'name' : str,
'format' : AudioFormat,
'size' : int,
'path' : str,
'metadata' : AudioMetadata
})
ImageAsset = TypedDict('ImageAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['image'],
'name' : str,
'format' : ImageFormat,
'size' : int,
'path' : str,
'metadata' : ImageMetadata
})
VideoAsset = TypedDict('VideoAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['video'],
'name' : str,
'format' : VideoFormat,
'size' : int,
'path' : str,
'metadata' : VideoMetadata
})
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
BenchmarkMode = Literal['warm', 'cold']
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
BenchmarkSet : TypeAlias = Dict[BenchmarkResolution, str]
+44 -22
View File
@@ -1,3 +1,4 @@
import subprocess
from typing import Iterator
import pytest
@@ -6,6 +7,18 @@ from starlette.testclient import TestClient
from facefusion import args_store, metadata, session_manager, state_manager
from facefusion.apis import asset_store
from facefusion.apis.core import create_api
from facefusion.download import conditional_download
from .helper import get_test_example_file, get_test_examples_directory
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download(get_test_examples_directory(),
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ])
@pytest.fixture(scope = 'module')
@@ -85,10 +98,23 @@ def test_set_state(test_client : TestClient) -> None:
def test_select_source_assets(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
access_token = create_session_body.get('access_token')
session_id = session_manager.find_session_id(access_token)
source_paths =\
[
get_test_example_file('source.jpg'),
get_test_example_file('source.jpg')
]
asset_ids =\
[
asset_store.register_asset('/path/to/source1.jpg'),
asset_store.register_asset('/path/to/source2.jpg')
asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'),
asset_store.create_asset(session_id, 'source', source_paths[1]).get('id')
]
select_response = test_client.put('/state?action=select&type=source', json =
@@ -98,18 +124,12 @@ def test_select_source_assets(test_client : TestClient) -> None:
assert select_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
select_response = test_client.put('/state?action=select&type=source', json =
{
'asset_ids': 'invalid'
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
'Authorization': 'Bearer ' + access_token
})
assert select_response.status_code == 404
@@ -119,36 +139,38 @@ def test_select_source_assets(test_client : TestClient) -> None:
'asset_ids': asset_ids
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
'Authorization': 'Bearer ' + access_token
})
select_body = select_response.json()
assert select_body.get('source_paths') == [ '/path/to/source1.jpg', '/path/to/source2.jpg' ]
assert select_body.get('source_paths') == source_paths
assert select_response.status_code == 200
def test_select_target_assets(test_client : TestClient) -> None:
asset_id = asset_store.register_asset('/path/to/target.jpg')
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
access_token = create_session_body.get('access_token')
session_id = session_manager.find_session_id(access_token)
target_path = get_test_example_file('target-240p.jpg')
asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id')
select_response = test_client.put('/state?action=select&type=target', json =
select_response = test_client.put('/state?action=select&type=target', json=
{
'asset_id': asset_id
})
assert select_response.status_code == 401
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
select_response = test_client.put('/state?action=select&type=target', json =
{
'asset_id': 'invalid'
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
'Authorization': 'Bearer ' + access_token
})
assert select_response.status_code == 404
@@ -158,9 +180,9 @@ def test_select_target_assets(test_client : TestClient) -> None:
'asset_id': asset_id
}, headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
'Authorization': 'Bearer ' + access_token
})
select_body = select_response.json()
assert select_body.get('target_path') == '/path/to/target.jpg'
assert select_body.get('target_path') == target_path
assert select_response.status_code == 200