mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-30 05:17:49 +02:00
upload asset endpoint
This commit is contained in:
@@ -1,20 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.audio import detect_audio_duration
|
||||
from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate
|
||||
from facefusion.filesystem import is_audio, is_image, is_video
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata
|
||||
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path : str) -> AudioMetadata:
|
||||
metadata : AudioMetadata =\
|
||||
{
|
||||
'duration' : detect_audio_duration(file_path),
|
||||
'sample_rate' : detect_audio_sample_rate(file_path),
|
||||
'frame_total' : detect_audio_frame_total(file_path),
|
||||
'channels' : detect_audio_channel_total(file_path),
|
||||
'format' : detect_audio_format(file_path)
|
||||
'duration': detect_audio_duration(file_path),
|
||||
'sample_rate': detect_audio_sample_rate(file_path),
|
||||
'frame_total': detect_audio_frame_total(file_path),
|
||||
'channels': detect_audio_channel_total(file_path),
|
||||
'format': detect_audio_format(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
@@ -22,7 +19,7 @@ def extract_audio_metadata(file_path : str) -> AudioMetadata:
|
||||
def extract_image_metadata(file_path : str) -> ImageMetadata:
|
||||
metadata : ImageMetadata =\
|
||||
{
|
||||
'resolution' : detect_image_resolution(file_path)
|
||||
'resolution': detect_image_resolution(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
@@ -30,19 +27,9 @@ def extract_image_metadata(file_path : str) -> ImageMetadata:
|
||||
def extract_video_metadata(file_path : str) -> VideoMetadata:
|
||||
metadata : VideoMetadata =\
|
||||
{
|
||||
'duration' : detect_video_duration(file_path),
|
||||
'frame_total' : count_video_frame_total(file_path),
|
||||
'fps' : detect_video_fps(file_path),
|
||||
'resolution' : detect_video_resolution(file_path)
|
||||
'duration': detect_video_duration(file_path),
|
||||
'frame_total': count_video_frame_total(file_path),
|
||||
'fps': detect_video_fps(file_path),
|
||||
'resolution': detect_video_resolution(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def detect_media_type(file_path : str) -> Optional[MediaType]:
|
||||
if is_audio(file_path):
|
||||
return 'audio'
|
||||
if is_image(file_path):
|
||||
return 'image'
|
||||
if is_video(file_path):
|
||||
return 'video'
|
||||
return None
|
||||
|
||||
@@ -3,82 +3,89 @@ import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name
|
||||
from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
from facefusion.apis.asset_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name, is_file, remove_file
|
||||
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, MediaType, SessionId, VideoAsset, VideoFormat
|
||||
|
||||
ASSET_STORE : AssetStore = {}
|
||||
|
||||
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, media_type : MediaType, file_path : str) -> Asset:
|
||||
asset_id = str(uuid.uuid4())
|
||||
media_type = detect_media_type(file_path)
|
||||
file_name = get_file_name(file_path)
|
||||
file_format = get_file_format(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 2)
|
||||
|
||||
if media_type:
|
||||
file_name = get_file_name(file_path)
|
||||
file_format = get_file_format(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 2)
|
||||
if session_id not in ASSET_STORE:
|
||||
ASSET_STORE[session_id] = {}
|
||||
|
||||
if session_id not in ASSET_STORE:
|
||||
ASSET_STORE[session_id] = {}
|
||||
if media_type == 'audio':
|
||||
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
})
|
||||
|
||||
if media_type == 'audio':
|
||||
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
})
|
||||
if media_type == 'image':
|
||||
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
})
|
||||
|
||||
if media_type == 'image':
|
||||
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
})
|
||||
if media_type == 'video':
|
||||
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
})
|
||||
|
||||
if media_type == 'video':
|
||||
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
})
|
||||
|
||||
return ASSET_STORE[session_id][asset_id]
|
||||
|
||||
return None
|
||||
return ASSET_STORE[session_id][asset_id]
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
|
||||
if session_id in ASSET_STORE:
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
return None
|
||||
|
||||
|
||||
def clear_session(session_id : SessionId) -> None:
|
||||
if session_id in ASSET_STORE:
|
||||
for asset in ASSET_STORE[session_id].values():
|
||||
file_path = asset.get('path')
|
||||
|
||||
if file_path and is_file(file_path):
|
||||
remove_file(file_path)
|
||||
|
||||
del ASSET_STORE[session_id]
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
ASSET_STORE.clear()
|
||||
for session_id in list(ASSET_STORE.keys()):
|
||||
clear_session(session_id)
|
||||
|
||||
@@ -3,6 +3,7 @@ from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Route, WebSocketRoute
|
||||
|
||||
from facefusion.apis.endpoints.assets import upload_asset
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
@@ -18,6 +19,7 @@ def create_api() -> Starlette:
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]),
|
||||
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import tempfile
|
||||
from typing import List, Tuple
|
||||
|
||||
from starlette.datastructures import UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
|
||||
|
||||
from facefusion import session_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
from facefusion.filesystem import get_file_extension, is_audio, is_file, is_image, is_video, remove_file
|
||||
from facefusion.types import MediaType
|
||||
|
||||
|
||||
async def upload_asset(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
asset_type = request.query_params.get('type')
|
||||
|
||||
if session_id and asset_type in [ 'source', 'target' ]:
|
||||
form = await request.form()
|
||||
files = [ file for file in form.getlist('file') if isinstance(file, UploadFile) ]
|
||||
|
||||
if asset_type == 'target':
|
||||
files = files[:1]
|
||||
|
||||
prepared_files = await prepare_files(files)
|
||||
|
||||
if prepared_files:
|
||||
asset_ids : List[str] = []
|
||||
|
||||
for file_path, media_type in prepared_files:
|
||||
asset = asset_store.create_asset(session_id, asset_type, media_type, file_path) #type:ignore[arg-type]
|
||||
|
||||
if asset:
|
||||
asset_id = asset.get('id')
|
||||
|
||||
if asset_id:
|
||||
asset_ids.append(asset_id)
|
||||
|
||||
if asset_ids:
|
||||
if asset_type == 'target':
|
||||
return JSONResponse({ 'asset_id': asset_ids[0] }, status_code = HTTP_201_CREATED)
|
||||
|
||||
return JSONResponse({ 'asset_ids': asset_ids }, status_code = HTTP_201_CREATED)
|
||||
|
||||
return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
async def prepare_files(files : List[UploadFile]) -> List[Tuple[str, MediaType]]:
|
||||
prepared_files : List[Tuple[str, MediaType]] = []
|
||||
|
||||
for file in files:
|
||||
file_extension = get_file_extension(file.filename)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix = file_extension, delete = False) as temp_file:
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
file_path = temp_file.name
|
||||
|
||||
if is_audio(file_path):
|
||||
prepared_files.append((file_path, 'audio'))
|
||||
continue
|
||||
if is_image(file_path):
|
||||
prepared_files.append((file_path, 'image'))
|
||||
continue
|
||||
if is_video(file_path):
|
||||
prepared_files.append((file_path, 'video'))
|
||||
continue
|
||||
|
||||
if is_file(file_path):
|
||||
remove_file(file_path)
|
||||
|
||||
return prepared_files
|
||||
+2
-1
@@ -232,7 +232,8 @@ VideoAsset = TypedDict('VideoAsset',
|
||||
})
|
||||
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]]
|
||||
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
|
||||
|
||||
BenchmarkMode = Literal['warm', 'cold']
|
||||
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, session_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.core import create_api
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
def test_client() -> Iterator[TestClient]:
|
||||
with TestClient(create_api()) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
asset_store.clear()
|
||||
|
||||
|
||||
def test_upload_asset_without_auth(test_client : TestClient) -> None:
|
||||
upload_response = test_client.post('/assets?type=source')
|
||||
|
||||
assert upload_response.status_code == 401
|
||||
|
||||
|
||||
def test_upload_asset_invalid_type(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
upload_response = test_client.post('/assets?type=invalid', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 400
|
||||
|
||||
|
||||
def test_upload_asset_no_file(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 400
|
||||
|
||||
|
||||
def test_upload_source_asset(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with open('.assets/examples/source.jpg', 'rb') as source_file:
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
{
|
||||
'file': ('source.jpg', source_file, 'image/jpeg')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 201
|
||||
assert upload_response.json().get('asset_ids')
|
||||
assert len(upload_response.json().get('asset_ids')) == 1
|
||||
|
||||
|
||||
def test_upload_multiple_source_assets(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with open('.assets/examples/source.jpg', 'rb') as source_file_1:
|
||||
with open('.assets/examples/source.jpg', 'rb') as source_file_2:
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
[
|
||||
('file', ('source1.jpg', source_file_1, 'image/jpeg')),
|
||||
('file', ('source2.jpg', source_file_2, 'image/jpeg'))
|
||||
])
|
||||
|
||||
assert upload_response.status_code == 201
|
||||
assert upload_response.json().get('asset_ids')
|
||||
assert len(upload_response.json().get('asset_ids')) == 2
|
||||
|
||||
|
||||
def test_upload_target_asset(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with open('.assets/examples/target-240p.mp4', 'rb') as target_file:
|
||||
upload_response = test_client.post('/assets?type=target', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
{
|
||||
'file': ('target.mp4', target_file, 'video/mp4')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 201
|
||||
assert upload_response.json().get('asset_id')
|
||||
|
||||
|
||||
def test_upload_target_multiple_files_uses_first(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
with open('.assets/examples/target-240p.mp4', 'rb') as target_file_1:
|
||||
with open('.assets/examples/target-240p.mp4', 'rb') as target_file_2:
|
||||
upload_response = test_client.post('/assets?type=target', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
[
|
||||
('file', ('target1.mp4', target_file_1, 'video/mp4')),
|
||||
('file', ('target2.mp4', target_file_2, 'video/mp4'))
|
||||
])
|
||||
|
||||
assert upload_response.status_code == 201
|
||||
assert upload_response.json().get('asset_id')
|
||||
|
||||
|
||||
def test_upload_unsupported_format(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
{
|
||||
'file': ('test.txt', b'invalid content', 'text/plain')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 400
|
||||
@@ -113,8 +113,8 @@ def test_select_source_assets(test_client : TestClient) -> None:
|
||||
]
|
||||
asset_ids =\
|
||||
[
|
||||
asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'),
|
||||
asset_store.create_asset(session_id, 'source', source_paths[1]).get('id')
|
||||
asset_store.create_asset(session_id, 'source', 'image', source_paths[0]).get('id'),
|
||||
asset_store.create_asset(session_id, 'source', 'image', source_paths[1]).get('id')
|
||||
]
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=source', json =
|
||||
@@ -156,7 +156,7 @@ def test_select_target_assets(test_client : TestClient) -> None:
|
||||
access_token = create_session_body.get('access_token')
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
target_path = get_test_example_file('target-240p.jpg')
|
||||
asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id')
|
||||
asset_id = asset_store.create_asset(session_id, 'target', 'image', target_path).get('id')
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=target', json=
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user