diff --git a/facefusion/apis/asset_helper.py b/facefusion/apis/asset_helper.py index 80989ef0..2538ee57 100644 --- a/facefusion/apis/asset_helper.py +++ b/facefusion/apis/asset_helper.py @@ -1,6 +1,9 @@ +from typing import Optional + from facefusion.audio import detect_audio_duration from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate -from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata +from facefusion.filesystem import is_audio, is_image, is_video +from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution @@ -33,3 +36,13 @@ def extract_video_metadata(file_path : str) -> VideoMetadata: 'resolution': detect_video_resolution(file_path) } return metadata + + +def detect_media_type(file_path : str) -> Optional[MediaType]: + if is_audio(file_path): + return 'audio' + if is_image(file_path): + return 'image' + if is_video(file_path): + return 'video' + return None diff --git a/facefusion/apis/endpoints/assets.py b/facefusion/apis/endpoints/assets.py index 7c319dc1..f030633d 100644 --- a/facefusion/apis/endpoints/assets.py +++ b/facefusion/apis/endpoints/assets.py @@ -1,5 +1,5 @@ import tempfile -from typing import List, Tuple +from typing import List from starlette.datastructures import UploadFile from starlette.requests import Request @@ -8,9 +8,9 @@ from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST from facefusion import session_manager from facefusion.apis import asset_store +from facefusion.apis.asset_helper import detect_media_type from facefusion.apis.endpoints.session import extract_access_token -from facefusion.filesystem import get_file_extension, is_audio, is_file, is_image, is_video, remove_file -from facefusion.types import MediaType +from facefusion.filesystem import get_file_extension, is_file, remove_file async def upload_asset(request : Request) -> JSONResponse: @@ -30,14 +30,17 @@ async def upload_asset(request : Request) -> JSONResponse: if media_files: asset_ids : List[str] = [] - for file_path, media_type in media_files: - asset = asset_store.create_asset(session_id, asset_type, media_type, file_path) #type:ignore[arg-type] + for media_file in media_files: + media_type = detect_media_type(media_file) - if asset: - asset_id = asset.get('id') + if media_type: + asset = asset_store.create_asset(session_id, asset_type, media_type, media_file) #type:ignore[arg-type] - if asset_id: - asset_ids.append(asset_id) + if asset: + asset_id = asset.get('id') + + if asset_id: + asset_ids.append(asset_id) if asset_ids: if asset_type == 'target': @@ -48,8 +51,8 @@ async def upload_asset(request : Request) -> JSONResponse: return JSONResponse({}, status_code = HTTP_400_BAD_REQUEST) -async def prepare_media_files(files : List[UploadFile]) -> List[Tuple[str, MediaType]]: - media_files : List[Tuple[str, MediaType]] = [] +async def prepare_media_files(files : List[UploadFile]) -> List[str]: + media_files : List[str] = [] for file in files: file_extension = get_file_extension(file.filename) @@ -59,17 +62,10 @@ async def prepare_media_files(files : List[UploadFile]) -> List[Tuple[str, Media temp_file.write(content) file_path = temp_file.name - media_type : MediaType | None = None - - if is_audio(file_path): - media_type = 'audio' - if is_image(file_path): - media_type = 'image' - if is_video(file_path): - media_type = 'video' + media_type = detect_media_type(file_path) if media_type: - media_files.append((file_path, media_type)) + media_files.append(file_path) if not media_type and is_file(file_path): remove_file(file_path)