diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py index 4e587986..72d8b4bc 100644 --- a/facefusion/apis/asset_store.py +++ b/facefusion/apis/asset_store.py @@ -3,18 +3,19 @@ import uuid from datetime import datetime, timedelta from typing import Optional, cast -from facefusion.apis.asset_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata -from facefusion.filesystem import get_file_format, get_file_name, is_file, remove_file -from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, MediaType, SessionId, VideoAsset, VideoFormat +from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata +from facefusion.filesystem import get_file_format, get_file_name +from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat ASSET_STORE : AssetStore = {} -def create_asset(session_id : SessionId, asset_type : AssetType, media_type : MediaType, file_path : str) -> Asset: +def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]: asset_id = str(uuid.uuid4()) - file_name = get_file_name(file_path) - file_format = get_file_format(file_path) - file_size = os.path.getsize(file_path) + asset_name = get_file_name(asset_path) + asset_format = get_file_format(asset_path) + asset_size = os.path.getsize(asset_path) + media_type = detect_media_type(asset_path) created_at = datetime.now() expires_at = created_at + timedelta(hours = 2) @@ -29,11 +30,11 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me 'expires_at': expires_at, 'type': asset_type, 'media': media_type, - 'name': file_name, - 'format': cast(AudioFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_audio_metadata(file_path) + 'name': asset_name, + 'format': cast(AudioFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_audio_metadata(asset_path) }) if media_type == 'image': @@ -44,11 +45,11 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me 'expires_at': expires_at, 'type': asset_type, 'media': media_type, - 'name': file_name, - 'format': cast(ImageFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_image_metadata(file_path) + 'name': asset_name, + 'format': cast(ImageFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_image_metadata(asset_path) }) if media_type == 'video': @@ -59,33 +60,21 @@ def create_asset(session_id : SessionId, asset_type : AssetType, media_type : Me 'expires_at': expires_at, 'type': asset_type, 'media': media_type, - 'name': file_name, - 'format': cast(VideoFormat, file_format), - 'size': file_size, - 'path': file_path, - 'metadata': extract_video_metadata(file_path) + 'name': asset_name, + 'format': cast(VideoFormat, asset_format), + 'size': asset_size, + 'path': asset_path, + 'metadata': extract_video_metadata(asset_path) }) - return ASSET_STORE[session_id][asset_id] + return ASSET_STORE[session_id].get(asset_id) -def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]: +def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]: if session_id in ASSET_STORE: return ASSET_STORE[session_id].get(asset_id) return None -def clear_session(session_id : SessionId) -> None: - if session_id in ASSET_STORE: - for asset in ASSET_STORE[session_id].values(): - file_path = asset.get('path') - - if is_file(file_path): - remove_file(file_path) - - del ASSET_STORE[session_id] - - def clear() -> None: - for session_id in list(ASSET_STORE.keys()): - clear_session(session_id) + ASSET_STORE.clear() diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py index f030633d..b4daa113 100644 --- a/facefusion/apis/endpoints/assets.py +++ b/facefusion/apis/endpoints/assets.py @@ -3,71 +3,58 @@ from typing import List from starlette.datastructures import UploadFile from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, Response from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST from facefusion import session_manager from facefusion.apis import asset_store from facefusion.apis.asset_helper import detect_media_type from facefusion.apis.endpoints.session import extract_access_token -from facefusion.filesystem import get_file_extension, is_file, remove_file +from facefusion.filesystem import get_file_extension, remove_file -async def upload_asset(request : Request) -> JSONResponse: +async def upload_asset(request : Request) -> Response: access_token = extract_access_token(request.scope) session_id = session_manager.find_session_id(access_token) asset_type = request.query_params.get('type') if session_id and asset_type in [ 'source', 'target' ]: form = await request.form() - files = [ file for file in form.getlist('file') if isinstance(file, UploadFile) ] + upload_files = form.getlist('file') + asset_paths = await save_asset_files(upload_files) - if asset_type == 'target': - files = files[:1] - - media_files = await prepare_media_files(files) - - if media_files: + if asset_paths: asset_ids : List[str] = [] - for media_file in media_files: - media_type = detect_media_type(media_file) + for asset_path in asset_paths: + asset = asset_store.create_asset(session_id, asset_type, asset_path) # type: ignore[arg-type] + asset_id = asset.get('id') - if media_type: - asset = asset_store.create_asset(session_id, asset_type, media_type, media_file) #type:ignore[arg-type] - - if asset: - asset_id = asset.get('id') - - if asset_id: - asset_ids.append(asset_id) + if asset_id: + asset_ids.append(asset_id) if asset_ids: - if asset_type == 'target': - return JSONResponse({ 'asset_id': asset_ids[0] }, status_code = HTTP_201_CREATED) + return JSONResponse( + { + 'asset_ids': asset_ids + }, status_code = HTTP_201_CREATED) - return JSONResponse({ 'asset_ids': asset_ids }, status_code = HTTP_201_CREATED) - - return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST) + return Response(status_code = HTTP_400_BAD_REQUEST) -async def prepare_media_files(files : List[UploadFile]) -> List[str]: - media_files : List[str] = [] +async def save_asset_files(upload_files : List[UploadFile]) -> List[str]: + asset_paths : List[str] = [] - for file in files: - file_extension = get_file_extension(file.filename) + for upload_file in upload_files: + upload_file_extension = get_file_extension(upload_file.filename) - with tempfile.NamedTemporaryFile(suffix = file_extension, delete = False) as temp_file: - content = await file.read() - temp_file.write(content) - file_path = temp_file.name + with tempfile.NamedTemporaryFile(suffix = upload_file_extension, delete = False) as temp_file: + temp_content = await upload_file.read() + temp_file.write(temp_content) - media_type = detect_media_type(file_path) + if detect_media_type(temp_file.name): + asset_paths.append(temp_file.name) + else: + remove_file(temp_file.name) - if media_type: - media_files.append(file_path) - - if not media_type and is_file(file_path): - remove_file(file_path) - - return media_files + return asset_paths diff --git a/facefusion/types.py b/facefusion/types.py index 36216e94..d3cc82ce 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -232,8 +232,7 @@ VideoAsset = TypedDict('VideoAsset', }) AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata -Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset -AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]] +AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]] BenchmarkMode = Literal['warm', 'cold'] BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] diff --git a/facefusion/vision.py b/facefusion/vision.py index 3a427c3a..42291f72 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -114,7 +114,7 @@ def predict_video_frame_total(video_path : str, fps : Fps, trim_frame_start : in return 0 -def detect_video_fps(video_path : str) -> Optional[float]: +def detect_video_fps(video_path : str) -> Optional[Fps]: if is_video(video_path): video_capture = get_video_capture(video_path) diff --git a/tests/test_api_assets.py b/tests/test_api_assets.py index 7483a884..bdf568ad 100644 --- a/tests/test_api_assets.py +++ b/tests/test_api_assets.py @@ -128,7 +128,7 @@ def test_upload_target_asset(test_client : TestClient) -> None: }) assert upload_response.status_code == 201 - assert upload_response.json().get('asset_id') + assert upload_response.json().get('asset_ids') def test_upload_unsupported_format(test_client : TestClient) -> None: @@ -143,7 +143,7 @@ def test_upload_unsupported_format(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + create_session_body.get('access_token') }, files = { - 'file': ('test.txt', b'invalid content', 'text/plain') + 'file': ('test.txt', b'invalid', 'text/plain') }) assert upload_response.status_code == 400 diff --git a/tests/test_api_state.py b/tests/test_api_state.py index f8c34113..1a6f857d 100644 --- a/tests/test_api_state.py +++ b/tests/test_api_state.py @@ -113,8 +113,8 @@ def test_select_source_assets(test_client : TestClient) -> None: ] asset_ids =\ [ - asset_store.create_asset(session_id, 'source', 'image', source_paths[0]).get('id'), - asset_store.create_asset(session_id, 'source', 'image', source_paths[1]).get('id') + asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'), + asset_store.create_asset(session_id, 'source', source_paths[1]).get('id') ] select_response = test_client.put('/state?action=select&type=source', json = @@ -156,7 +156,7 @@ def test_select_target_assets(test_client : TestClient) -> None: access_token = create_session_body.get('access_token') session_id = session_manager.find_session_id(access_token) target_path = get_test_example_file('target-240p.jpg') - asset_id = asset_store.create_asset(session_id, 'target', 'image', target_path).get('id') + asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id') select_response = test_client.put('/state?action=select&type=target', json= {