From e35038bc58091f6b8816fbb7b82dfa390c50ea55 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 15 Jan 2026 09:35:41 +0100 Subject: [PATCH] Polish asset store and helpers --- facefusion/apis/api_helper.py | 36 ---------- facefusion/apis/asset_helper.py | 48 ++++++++++++++ facefusion/apis/asset_store.py | 112 ++++++++++++++++---------------- facefusion/types.py | 3 +- tests/test_ffprobe.py | 46 ++++--------- 5 files changed, 119 insertions(+), 126 deletions(-) create mode 100644 facefusion/apis/asset_helper.py diff --git a/facefusion/apis/api_helper.py b/facefusion/apis/api_helper.py index ed8d7e03..1c3cbc73 100644 --- a/facefusion/apis/api_helper.py +++ b/facefusion/apis/api_helper.py @@ -3,11 +3,6 @@ from typing import Optional from starlette.datastructures import Headers from starlette.types import Scope -from facefusion.audio import detect_audio_duration -from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate -from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata -from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution - def get_sec_websocket_protocol(scope : Scope) -> Optional[str]: protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol') @@ -17,34 +12,3 @@ def get_sec_websocket_protocol(scope : Scope) -> Optional[str]: return protocol.strip() return None - - -def extract_audio_metadata(file_path : str) -> AudioMetadata: - metadata : AudioMetadata =\ - { - 'duration' : detect_audio_duration(file_path), - 'sample_rate' : detect_audio_sample_rate(file_path), - 'frame_total' : detect_audio_frame_total(file_path), - 'channels' : detect_audio_channel_total(file_path), - 'format' : detect_audio_format(file_path) - } - return metadata - - -def extract_image_metadata(file_path : str) -> ImageMetadata: - metadata : ImageMetadata =\ - { - 'resolution' : detect_image_resolution(file_path) - } - return metadata - - -def extract_video_metadata(file_path : str) -> VideoMetadata: - metadata : VideoMetadata =\ - { - 'duration' : detect_video_duration(file_path), - 'frame_total' : count_video_frame_total(file_path), - 'fps' : detect_video_fps(file_path), - 'resolution' : detect_video_resolution(file_path) - } - return metadata diff --git a/facefusion/apis/asset_helper.py b/facefusion/apis/asset_helper.py new file mode 100644 index 00000000..979d6d83 --- /dev/null +++ b/facefusion/apis/asset_helper.py @@ -0,0 +1,48 @@ +from typing import Optional + +from facefusion.audio import detect_audio_duration +from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate +from facefusion.filesystem import is_audio, is_image, is_video +from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata +from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution + + +def extract_audio_metadata(file_path : str) -> AudioMetadata: + metadata : AudioMetadata =\ + { + 'duration' : detect_audio_duration(file_path), + 'sample_rate' : detect_audio_sample_rate(file_path), + 'frame_total' : detect_audio_frame_total(file_path), + 'channels' : detect_audio_channel_total(file_path), + 'format' : detect_audio_format(file_path) + } + return metadata + + +def extract_image_metadata(file_path : str) -> ImageMetadata: + metadata : ImageMetadata =\ + { + 'resolution' : detect_image_resolution(file_path) + } + return metadata + + +def extract_video_metadata(file_path : str) -> VideoMetadata: + metadata : VideoMetadata =\ + { + 'duration' : detect_video_duration(file_path), + 'frame_total' : count_video_frame_total(file_path), + 'fps' : detect_video_fps(file_path), + 'resolution' : detect_video_resolution(file_path) + } + return metadata + + +def detect_media_type(file_path : str) -> Optional[MediaType]: + if is_audio(file_path): + return 'audio' + if is_image(file_path): + return 'image' + if is_video(file_path): + return 'video' + return None diff --git a/facefusion/apis/asset_store.py b/facefusion/apis/asset_store.py index c437dc73..66520e13 100644 --- a/facefusion/apis/asset_store.py +++ b/facefusion/apis/asset_store.py @@ -3,20 +3,14 @@ import uuid from datetime import datetime, timedelta from typing import Optional, cast -from facefusion.apis.api_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata -from facefusion.filesystem import get_file_format, get_file_name, is_audio, is_image, is_video -from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioFormat, ImageFormat, MediaType, SessionId, VideoFormat +from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata +from facefusion.filesystem import get_file_format, get_file_name +from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat ASSET_STORE : AssetStore = {} -def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]: - if session_id in ASSET_STORE: - return ASSET_STORE[session_id].get(asset_id) - return None - - -def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[Asset]: +def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]: asset_id = str(uuid.uuid4()) media_type = detect_media_type(file_path) @@ -25,58 +19,66 @@ def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str file_format = get_file_format(file_path) file_size = os.path.getsize(file_path) created_at = datetime.now() - expires_at = created_at + timedelta(hours = 24) - asset =\ - { - 'id': asset_id, - 'created_at': created_at, - 'expires_at': expires_at, - 'type': asset_type, - 'name': file_name, - 'size': file_size, - 'path': file_path - } - - if media_type == 'audio': - asset.update( - { - 'media': 'audio', - 'format': cast(AudioFormat, file_format), - 'metadata': extract_audio_metadata(file_path) - }) - if media_type == 'image': - asset.update( - { - 'media': 'image', - 'format': cast(ImageFormat, file_format), - 'metadata': extract_image_metadata(file_path) - }) - if media_type == 'video': - asset.update( - { - 'media': 'video', - 'format': cast(VideoFormat, file_format), - 'metadata': extract_video_metadata(file_path) - }) + expires_at = created_at + timedelta(hours = 2) if session_id not in ASSET_STORE: ASSET_STORE[session_id] = {} - ASSET_STORE[session_id][asset_id] = asset #type:ignore[assignment] - return asset #type:ignore[return-value] + if media_type == 'audio': + ASSET_STORE[session_id][asset_id] = cast(AudioAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(AudioFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_audio_metadata(file_path) + }) + if media_type == 'image': + ASSET_STORE[session_id][asset_id] = cast(ImageAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(ImageFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_image_metadata(file_path) + }) + + if media_type == 'video': + ASSET_STORE[session_id][asset_id] = cast(VideoAsset, + { + 'id': asset_id, + 'created_at': created_at, + 'expires_at': expires_at, + 'type': asset_type, + 'media': media_type, + 'name': file_name, + 'format': cast(VideoFormat, file_format), + 'size': file_size, + 'path': file_path, + 'metadata': extract_video_metadata(file_path) + }) + + return ASSET_STORE[session_id][asset_id] + + return None + + +def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]: + if session_id in ASSET_STORE: + return ASSET_STORE[session_id].get(asset_id) return None def clear() -> None: ASSET_STORE.clear() - - -def detect_media_type(file_path : str) -> Optional[MediaType]: - if is_image(file_path): - return 'image' - if is_video(file_path): - return 'video' - if is_audio(file_path): - return 'audio' - return None diff --git a/facefusion/types.py b/facefusion/types.py index bdac9304..cee45f81 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -231,9 +231,8 @@ VideoAsset = TypedDict('VideoAsset', 'metadata' : VideoMetadata }) -Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata -AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]] +AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]] BenchmarkMode = Literal['warm', 'cold'] BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] diff --git a/tests/test_ffprobe.py b/tests/test_ffprobe.py index a7cc9f7b..87f1a9fd 100644 --- a/tests/test_ffprobe.py +++ b/tests/test_ffprobe.py @@ -15,48 +15,28 @@ def before_all() -> None: [ 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3' ]) - subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), get_test_example_file('source.wav') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), '-t', '1.9', '-ar', '48000', '-ac', '2', get_test_example_file('source.wav') ]) def test_detect_audio_sample_rate() -> None: - audio_sample_rate = detect_audio_sample_rate(get_test_example_file('source.mp3')) - assert audio_sample_rate == 44100 - - audio_sample_rate = detect_audio_sample_rate(get_test_example_file('source.wav')) - assert audio_sample_rate == 44100 - - audio_sample_rate = detect_audio_sample_rate(get_test_example_file('invalid.mp3')) - assert audio_sample_rate is None + assert detect_audio_sample_rate(get_test_example_file('source.mp3')) == 44100 + assert detect_audio_sample_rate(get_test_example_file('source.wav')) == 44100 + assert detect_audio_sample_rate(get_test_example_file('invalid.mp3')) is None def test_detect_audio_channel_total() -> None: - audio_channel_total = detect_audio_channel_total(get_test_example_file('source.mp3')) - assert audio_channel_total == 1 - - audio_channel_total = detect_audio_channel_total(get_test_example_file('source.wav')) - assert audio_channel_total == 1 - - audio_channel_total = detect_audio_channel_total(get_test_example_file('invalid.mp3')) - assert audio_channel_total is None + assert detect_audio_channel_total(get_test_example_file('source.mp3')) == 1 + assert detect_audio_channel_total(get_test_example_file('source.wav')) == 1 + assert detect_audio_channel_total(get_test_example_file('invalid.mp3')) is None def test_detect_audio_frame_total() -> None: - audio_frame_total = detect_audio_frame_total(get_test_example_file('source.mp3')) - assert audio_frame_total == 167039 - - audio_frame_total = detect_audio_frame_total(get_test_example_file('source.wav')) - assert audio_frame_total == 167039 - - audio_frame_total = detect_audio_frame_total(get_test_example_file('invalid.mp3')) - assert audio_frame_total is None + assert detect_audio_frame_total(get_test_example_file('source.mp3')) == 167039 + assert detect_audio_frame_total(get_test_example_file('source.wav')) == 167039 + assert detect_audio_frame_total(get_test_example_file('invalid.mp3')) is None def test_detect_audio_format() -> None: - audio_format = detect_audio_format(get_test_example_file('source.mp3')) - assert audio_format == 'mp3' - - audio_format = detect_audio_format(get_test_example_file('source.wav')) - assert audio_format == 'pcm_s16le' - - audio_format = detect_audio_format(get_test_example_file('invalid.mp3')) - assert audio_format is None + assert detect_audio_format(get_test_example_file('source.mp3')) == 'mp3' + assert detect_audio_format(get_test_example_file('source.wav')) == 'pcm_s16le' + assert detect_audio_format(get_test_example_file('invalid.mp3')) is None