upload asset endpoint

This commit is contained in:
harisreedhar
2026-01-16 17:05:17 +05:30
parent 0d86679c67
commit a09c078c90
7 changed files with 321 additions and 89 deletions
+11 -24
View File
@@ -1,20 +1,17 @@
from typing import Optional
from facefusion.audio import detect_audio_duration
from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate
from facefusion.filesystem import is_audio, is_image, is_video
from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata
from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
def extract_audio_metadata(file_path : str) -> AudioMetadata:
metadata : AudioMetadata =\
{
'duration' : detect_audio_duration(file_path),
'sample_rate' : detect_audio_sample_rate(file_path),
'frame_total' : detect_audio_frame_total(file_path),
'channels' : detect_audio_channel_total(file_path),
'format' : detect_audio_format(file_path)
'duration': detect_audio_duration(file_path),
'sample_rate': detect_audio_sample_rate(file_path),
'frame_total': detect_audio_frame_total(file_path),
'channels': detect_audio_channel_total(file_path),
'format': detect_audio_format(file_path)
}
return metadata
@@ -22,7 +19,7 @@ def extract_audio_metadata(file_path : str) -> AudioMetadata:
def extract_image_metadata(file_path : str) -> ImageMetadata:
metadata : ImageMetadata =\
{
'resolution' : detect_image_resolution(file_path)
'resolution': detect_image_resolution(file_path)
}
return metadata
@@ -30,19 +27,9 @@ def extract_image_metadata(file_path : str) -> ImageMetadata:
def extract_video_metadata(file_path : str) -> VideoMetadata:
metadata : VideoMetadata =\
{
'duration' : detect_video_duration(file_path),
'frame_total' : count_video_frame_total(file_path),
'fps' : detect_video_fps(file_path),
'resolution' : detect_video_resolution(file_path)
'duration': detect_video_duration(file_path),
'frame_total': count_video_frame_total(file_path),
'fps': detect_video_fps(file_path),
'resolution': detect_video_resolution(file_path)
}
return metadata
def detect_media_type(file_path : str) -> Optional[MediaType]:
if is_audio(file_path):
return 'audio'
if is_image(file_path):
return 'image'
if is_video(file_path):
return 'video'
return None
+68 -61
View File
@@ -3,82 +3,89 @@ import uuid
from datetime import datetime, timedelta
from typing import Optional, cast
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
from facefusion.filesystem import get_file_format, get_file_name
from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
from facefusion.apis.asset_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
from facefusion.filesystem import get_file_format, get_file_name, is_file, remove_file
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, MediaType, SessionId, VideoAsset, VideoFormat
ASSET_STORE : AssetStore = {}
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
def create_asset(session_id : SessionId, asset_type : AssetType, media_type : MediaType, file_path : str) -> Asset:
asset_id = str(uuid.uuid4())
media_type = detect_media_type(file_path)
file_name = get_file_name(file_path)
file_format = get_file_format(file_path)
file_size = os.path.getsize(file_path)
created_at = datetime.now()
expires_at = created_at + timedelta(hours = 2)
if media_type:
file_name = get_file_name(file_path)
file_format = get_file_format(file_path)
file_size = os.path.getsize(file_path)
created_at = datetime.now()
expires_at = created_at + timedelta(hours = 2)
if session_id not in ASSET_STORE:
ASSET_STORE[session_id] = {}
if session_id not in ASSET_STORE:
ASSET_STORE[session_id] = {}
if media_type == 'audio':
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(AudioFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_audio_metadata(file_path)
})
if media_type == 'audio':
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(AudioFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_audio_metadata(file_path)
})
if media_type == 'image':
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(ImageFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_image_metadata(file_path)
})
if media_type == 'image':
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(ImageFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_image_metadata(file_path)
})
if media_type == 'video':
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(VideoFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_video_metadata(file_path)
})
if media_type == 'video':
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': file_name,
'format': cast(VideoFormat, file_format),
'size': file_size,
'path': file_path,
'metadata': extract_video_metadata(file_path)
})
return ASSET_STORE[session_id][asset_id]
return None
return ASSET_STORE[session_id][asset_id]
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
if session_id in ASSET_STORE:
return ASSET_STORE[session_id].get(asset_id)
return None
def clear_session(session_id : SessionId) -> None:
if session_id in ASSET_STORE:
for asset in ASSET_STORE[session_id].values():
file_path = asset.get('path')
if file_path and is_file(file_path):
remove_file(file_path)
del ASSET_STORE[session_id]
def clear() -> None:
ASSET_STORE.clear()
for session_id in list(ASSET_STORE.keys()):
clear_session(session_id)
+2
View File
@@ -3,6 +3,7 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route, WebSocketRoute
from facefusion.apis.endpoints.assets import upload_asset
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
@@ -18,6 +19,7 @@ def create_api() -> Starlette:
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]),
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]),
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
]
+75
View File
@@ -0,0 +1,75 @@
import tempfile
from typing import List, Tuple
from starlette.datastructures import UploadFile
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
from facefusion import session_manager
from facefusion.apis import asset_store
from facefusion.apis.endpoints.session import extract_access_token
from facefusion.filesystem import get_file_extension, is_audio, is_file, is_image, is_video, remove_file
from facefusion.types import MediaType
async def upload_asset(request : Request) -> JSONResponse:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
asset_type = request.query_params.get('type')
if session_id and asset_type in [ 'source', 'target' ]:
form = await request.form()
files = [ file for file in form.getlist('file') if isinstance(file, UploadFile) ]
if asset_type == 'target':
files = files[:1]
prepared_files = await prepare_files(files)
if prepared_files:
asset_ids : List[str] = []
for file_path, media_type in prepared_files:
asset = asset_store.create_asset(session_id, asset_type, media_type, file_path) #type:ignore[arg-type]
if asset:
asset_id = asset.get('id')
if asset_id:
asset_ids.append(asset_id)
if asset_ids:
if asset_type == 'target':
return JSONResponse({ 'asset_id': asset_ids[0] }, status_code = HTTP_201_CREATED)
return JSONResponse({ 'asset_ids': asset_ids }, status_code = HTTP_201_CREATED)
return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST)
async def prepare_files(files : List[UploadFile]) -> List[Tuple[str, MediaType]]:
prepared_files : List[Tuple[str, MediaType]] = []
for file in files:
file_extension = get_file_extension(file.filename)
with tempfile.NamedTemporaryFile(suffix = file_extension, delete = False) as temp_file:
content = await file.read()
temp_file.write(content)
file_path = temp_file.name
if is_audio(file_path):
prepared_files.append((file_path, 'audio'))
continue
if is_image(file_path):
prepared_files.append((file_path, 'image'))
continue
if is_video(file_path):
prepared_files.append((file_path, 'video'))
continue
if is_file(file_path):
remove_file(file_path)
return prepared_files
+2 -1
View File
@@ -232,7 +232,8 @@ VideoAsset = TypedDict('VideoAsset',
})
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]]
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
BenchmarkMode = Literal['warm', 'cold']
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
+160
View File
@@ -0,0 +1,160 @@
from typing import Iterator
import pytest
from starlette.testclient import TestClient
from facefusion import metadata, session_manager
from facefusion.apis import asset_store
from facefusion.apis.core import create_api
@pytest.fixture(scope = 'module')
def test_client() -> Iterator[TestClient]:
with TestClient(create_api()) as test_client:
yield test_client
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
asset_store.clear()
def test_upload_asset_without_auth(test_client : TestClient) -> None:
upload_response = test_client.post('/assets?type=source')
assert upload_response.status_code == 401
def test_upload_asset_invalid_type(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
upload_response = test_client.post('/assets?type=invalid', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert upload_response.status_code == 400
def test_upload_asset_no_file(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
upload_response = test_client.post('/assets?type=source', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert upload_response.status_code == 400
def test_upload_source_asset(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with open('.assets/examples/source.jpg', 'rb') as source_file:
upload_response = test_client.post('/assets?type=source', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
}, files =
{
'file': ('source.jpg', source_file, 'image/jpeg')
})
assert upload_response.status_code == 201
assert upload_response.json().get('asset_ids')
assert len(upload_response.json().get('asset_ids')) == 1
def test_upload_multiple_source_assets(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with open('.assets/examples/source.jpg', 'rb') as source_file_1:
with open('.assets/examples/source.jpg', 'rb') as source_file_2:
upload_response = test_client.post('/assets?type=source', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
}, files =
[
('file', ('source1.jpg', source_file_1, 'image/jpeg')),
('file', ('source2.jpg', source_file_2, 'image/jpeg'))
])
assert upload_response.status_code == 201
assert upload_response.json().get('asset_ids')
assert len(upload_response.json().get('asset_ids')) == 2
def test_upload_target_asset(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with open('.assets/examples/target-240p.mp4', 'rb') as target_file:
upload_response = test_client.post('/assets?type=target', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
}, files =
{
'file': ('target.mp4', target_file, 'video/mp4')
})
assert upload_response.status_code == 201
assert upload_response.json().get('asset_id')
def test_upload_target_multiple_files_uses_first(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
with open('.assets/examples/target-240p.mp4', 'rb') as target_file_1:
with open('.assets/examples/target-240p.mp4', 'rb') as target_file_2:
upload_response = test_client.post('/assets?type=target', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
}, files =
[
('file', ('target1.mp4', target_file_1, 'video/mp4')),
('file', ('target2.mp4', target_file_2, 'video/mp4'))
])
assert upload_response.status_code == 201
assert upload_response.json().get('asset_id')
def test_upload_unsupported_format(test_client : TestClient) -> None:
create_session_response = test_client.post('/session', json =
{
'client_version': metadata.get('version')
})
create_session_body = create_session_response.json()
upload_response = test_client.post('/assets?type=source', headers =
{
'Authorization': 'Bearer ' + create_session_body.get('access_token')
}, files =
{
'file': ('test.txt', b'invalid content', 'text/plain')
})
assert upload_response.status_code == 400
+3 -3
View File
@@ -113,8 +113,8 @@ def test_select_source_assets(test_client : TestClient) -> None:
]
asset_ids =\
[
asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'),
asset_store.create_asset(session_id, 'source', source_paths[1]).get('id')
asset_store.create_asset(session_id, 'source', 'image', source_paths[0]).get('id'),
asset_store.create_asset(session_id, 'source', 'image', source_paths[1]).get('id')
]
select_response = test_client.put('/state?action=select&type=source', json =
@@ -156,7 +156,7 @@ def test_select_target_assets(test_client : TestClient) -> None:
access_token = create_session_body.get('access_token')
session_id = session_manager.find_session_id(access_token)
target_path = get_test_example_file('target-240p.jpg')
asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id')
asset_id = asset_store.create_asset(session_id, 'target', 'image', target_path).get('id')
select_response = test_client.put('/state?action=select&type=target', json=
{