mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 17:36:16 +02:00
Revamp the upload
This commit is contained in:
@@ -3,18 +3,19 @@ import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from facefusion.apis.asset_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name, is_file, remove_file
|
||||
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, MediaType, SessionId, VideoAsset, VideoFormat
|
||||
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name
|
||||
from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
|
||||
ASSET_STORE : AssetStore = {}
|
||||
|
||||
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, media_type : MediaType, file_path : str) -> Asset:
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
asset_id = str(uuid.uuid4())
|
||||
file_name = get_file_name(file_path)
|
||||
file_format = get_file_format(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
asset_name = get_file_name(asset_path)
|
||||
asset_format = get_file_format(asset_path)
|
||||
asset_size = os.path.getsize(asset_path)
|
||||
media_type = detect_media_type(asset_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 2)
|
||||
|
||||
@@ -29,11 +30,11 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
'name': asset_name,
|
||||
'format': cast(AudioFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_audio_metadata(asset_path)
|
||||
})
|
||||
|
||||
if media_type == 'image':
|
||||
@@ -44,11 +45,11 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
'name': asset_name,
|
||||
'format': cast(ImageFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_image_metadata(asset_path)
|
||||
})
|
||||
|
||||
if media_type == 'video':
|
||||
@@ -59,33 +60,21 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
'name': asset_name,
|
||||
'format': cast(VideoFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_video_metadata(asset_path)
|
||||
})
|
||||
|
||||
return ASSET_STORE[session_id][asset_id]
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
if session_id in ASSET_STORE:
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
return None
|
||||
|
||||
|
||||
def clear_session(session_id : SessionId) -> None:
|
||||
if session_id in ASSET_STORE:
|
||||
for asset in ASSET_STORE[session_id].values():
|
||||
file_path = asset.get('path')
|
||||
|
||||
if is_file(file_path):
|
||||
remove_file(file_path)
|
||||
|
||||
del ASSET_STORE[session_id]
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
for session_id in list(ASSET_STORE.keys()):
|
||||
clear_session(session_id)
|
||||
ASSET_STORE.clear()
|
||||
|
||||
@@ -3,71 +3,58 @@ from typing import List
|
||||
|
||||
from starlette.datastructures import UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
|
||||
|
||||
from facefusion import session_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.asset_helper import detect_media_type
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
from facefusion.filesystem import get_file_extension, is_file, remove_file
|
||||
from facefusion.filesystem import get_file_extension, remove_file
|
||||
|
||||
|
||||
async def upload_asset(request : Request) -> JSONResponse:
|
||||
async def upload_asset(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
asset_type = request.query_params.get('type')
|
||||
|
||||
if session_id and asset_type in [ 'source', 'target' ]:
|
||||
form = await request.form()
|
||||
files = [ file for file in form.getlist('file') if isinstance(file, UploadFile) ]
|
||||
upload_files = form.getlist('file')
|
||||
asset_paths = await save_asset_files(upload_files)
|
||||
|
||||
if asset_type == 'target':
|
||||
files = files[:1]
|
||||
|
||||
media_files = await prepare_media_files(files)
|
||||
|
||||
if media_files:
|
||||
if asset_paths:
|
||||
asset_ids : List[str] = []
|
||||
|
||||
for media_file in media_files:
|
||||
media_type = detect_media_type(media_file)
|
||||
for asset_path in asset_paths:
|
||||
asset = asset_store.create_asset(session_id, asset_type, asset_path) # type: ignore[arg-type]
|
||||
asset_id = asset.get('id')
|
||||
|
||||
if media_type:
|
||||
asset = asset_store.create_asset(session_id, asset_type, media_type, media_file) #type:ignore[arg-type]
|
||||
|
||||
if asset:
|
||||
asset_id = asset.get('id')
|
||||
|
||||
if asset_id:
|
||||
asset_ids.append(asset_id)
|
||||
if asset_id:
|
||||
asset_ids.append(asset_id)
|
||||
|
||||
if asset_ids:
|
||||
if asset_type == 'target':
|
||||
return JSONResponse({ 'asset_id': asset_ids[0] }, status_code = HTTP_201_CREATED)
|
||||
return JSONResponse(
|
||||
{
|
||||
'asset_ids': asset_ids
|
||||
}, status_code = HTTP_201_CREATED)
|
||||
|
||||
return JSONResponse({ 'asset_ids': asset_ids }, status_code = HTTP_201_CREATED)
|
||||
|
||||
return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST)
|
||||
return Response(status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
async def prepare_media_files(files : List[UploadFile]) -> List[str]:
|
||||
media_files : List[str] = []
|
||||
async def save_asset_files(upload_files : List[UploadFile]) -> List[str]:
|
||||
asset_paths : List[str] = []
|
||||
|
||||
for file in files:
|
||||
file_extension = get_file_extension(file.filename)
|
||||
for upload_file in upload_files:
|
||||
upload_file_extension = get_file_extension(upload_file.filename)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix = file_extension, delete = False) as temp_file:
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
file_path = temp_file.name
|
||||
with tempfile.NamedTemporaryFile(suffix = upload_file_extension, delete = False) as temp_file:
|
||||
temp_content = await upload_file.read()
|
||||
temp_file.write(temp_content)
|
||||
|
||||
media_type = detect_media_type(file_path)
|
||||
if detect_media_type(temp_file.name):
|
||||
asset_paths.append(temp_file.name)
|
||||
else:
|
||||
remove_file(temp_file.name)
|
||||
|
||||
if media_type:
|
||||
media_files.append(file_path)
|
||||
|
||||
if not media_type and is_file(file_path):
|
||||
remove_file(file_path)
|
||||
|
||||
return media_files
|
||||
return asset_paths
|
||||
|
||||
+1
-2
@@ -232,8 +232,7 @@ VideoAsset = TypedDict('VideoAsset',
|
||||
})
|
||||
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]]
|
||||
|
||||
BenchmarkMode = Literal['warm', 'cold']
|
||||
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
|
||||
|
||||
@@ -114,7 +114,7 @@ def predict_video_frame_total(video_path : str, fps : Fps, trim_frame_start : in
|
||||
return 0
|
||||
|
||||
|
||||
def detect_video_fps(video_path : str) -> Optional[float]:
|
||||
def detect_video_fps(video_path : str) -> Optional[Fps]:
|
||||
if is_video(video_path):
|
||||
video_capture = get_video_capture(video_path)
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ def test_upload_target_asset(test_client : TestClient) -> None:
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 201
|
||||
assert upload_response.json().get('asset_id')
|
||||
assert upload_response.json().get('asset_ids')
|
||||
|
||||
|
||||
def test_upload_unsupported_format(test_client : TestClient) -> None:
|
||||
@@ -143,7 +143,7 @@ def test_upload_unsupported_format(test_client : TestClient) -> None:
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
}, files =
|
||||
{
|
||||
'file': ('test.txt', b'invalid content', 'text/plain')
|
||||
'file': ('test.txt', b'invalid', 'text/plain')
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 400
|
||||
|
||||
@@ -113,8 +113,8 @@ def test_select_source_assets(test_client : TestClient) -> None:
|
||||
]
|
||||
asset_ids =\
|
||||
[
|
||||
asset_store.create_asset(session_id, 'source', 'image', source_paths[0]).get('id'),
|
||||
asset_store.create_asset(session_id, 'source', 'image', source_paths[1]).get('id')
|
||||
asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'),
|
||||
asset_store.create_asset(session_id, 'source', source_paths[1]).get('id')
|
||||
]
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=source', json =
|
||||
@@ -156,7 +156,7 @@ def test_select_target_assets(test_client : TestClient) -> None:
|
||||
access_token = create_session_body.get('access_token')
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
target_path = get_test_example_file('target-240p.jpg')
|
||||
asset_id = asset_store.create_asset(session_id, 'target', 'image', target_path).get('id')
|
||||
asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id')
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=target', json=
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user