mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 13:05:59 +02:00
Polish asset store and helpers
This commit is contained in:
@@ -3,11 +3,6 @@ from typing import Optional
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import Scope
|
||||
|
||||
from facefusion.audio import detect_audio_duration
|
||||
from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata
|
||||
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
|
||||
protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol')
|
||||
@@ -17,34 +12,3 @@ def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
|
||||
return protocol.strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path : str) -> AudioMetadata:
    """
    Build the metadata mapping for the audio file at ``file_path``.

    The detectors are called with ``file_path`` in the order duration,
    sample rate, frame total, channel total, format, and their results are
    stored under the keys of the same meaning.
    """
    return\
    {
        'duration' : detect_audio_duration(file_path),
        'sample_rate' : detect_audio_sample_rate(file_path),
        'frame_total' : detect_audio_frame_total(file_path),
        'channels' : detect_audio_channel_total(file_path),
        'format' : detect_audio_format(file_path)
    }
|
||||
|
||||
|
||||
def extract_image_metadata(file_path : str) -> ImageMetadata:
    """
    Build the metadata mapping for the image file at ``file_path``.

    The only entry is 'resolution', taken from ``detect_image_resolution``.
    """
    return\
    {
        'resolution' : detect_image_resolution(file_path)
    }
|
||||
|
||||
|
||||
def extract_video_metadata(file_path : str) -> VideoMetadata:
    """
    Build the metadata mapping for the video file at ``file_path``.

    The detectors are called with ``file_path`` in the order duration,
    frame total, fps, resolution, and their results are stored under the
    keys of the same name.
    """
    return\
    {
        'duration' : detect_video_duration(file_path),
        'frame_total' : count_video_frame_total(file_path),
        'fps' : detect_video_fps(file_path),
        'resolution' : detect_video_resolution(file_path)
    }
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.audio import detect_audio_duration
|
||||
from facefusion.ffprobe import detect_audio_channel_total, detect_audio_format, detect_audio_frame_total, detect_audio_sample_rate
|
||||
from facefusion.filesystem import is_audio, is_image, is_video
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata
|
||||
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path : str) -> AudioMetadata:
    """
    Assemble the AudioMetadata mapping for the audio file at ``file_path``.

    Every value comes from one detector call on ``file_path``; the calls run
    in the same order as the keys of the returned mapping.
    """
    duration = detect_audio_duration(file_path)
    sample_rate = detect_audio_sample_rate(file_path)
    frame_total = detect_audio_frame_total(file_path)
    channel_total = detect_audio_channel_total(file_path)
    audio_format = detect_audio_format(file_path)

    audio_metadata : AudioMetadata =\
    {
        'duration' : duration,
        'sample_rate' : sample_rate,
        'frame_total' : frame_total,
        'channels' : channel_total,
        'format' : audio_format
    }
    return audio_metadata
|
||||
|
||||
|
||||
def extract_image_metadata(file_path : str) -> ImageMetadata:
    """
    Assemble the ImageMetadata mapping for the image file at ``file_path``.

    The mapping holds a single 'resolution' entry produced by
    ``detect_image_resolution``.
    """
    resolution = detect_image_resolution(file_path)

    image_metadata : ImageMetadata =\
    {
        'resolution' : resolution
    }
    return image_metadata
|
||||
|
||||
|
||||
def extract_video_metadata(file_path : str) -> VideoMetadata:
    """
    Assemble the VideoMetadata mapping for the video file at ``file_path``.

    Every value comes from one detector call on ``file_path``; the calls run
    in the same order as the keys of the returned mapping.
    """
    duration = detect_video_duration(file_path)
    frame_total = count_video_frame_total(file_path)
    fps = detect_video_fps(file_path)
    resolution = detect_video_resolution(file_path)

    video_metadata : VideoMetadata =\
    {
        'duration' : duration,
        'frame_total' : frame_total,
        'fps' : fps,
        'resolution' : resolution
    }
    return video_metadata
|
||||
|
||||
|
||||
def detect_media_type(file_path : str) -> Optional[MediaType]:
    """
    Classify ``file_path`` as 'audio', 'image' or 'video'.

    The checks run in that order and stop at the first match; None is
    returned when none of ``is_audio``, ``is_image`` and ``is_video`` accepts
    the path.
    """
    media_type : Optional[MediaType] = None

    if is_audio(file_path):
        media_type = 'audio'
    elif is_image(file_path):
        media_type = 'image'
    elif is_video(file_path):
        media_type = 'video'
    return media_type
|
||||
@@ -3,20 +3,14 @@ import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
from facefusion.apis.api_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name, is_audio, is_image, is_video
|
||||
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioFormat, ImageFormat, MediaType, SessionId, VideoFormat
|
||||
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name
|
||||
from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
|
||||
# Module-level asset registry; starts empty. Per the AssetStore alias it maps a session id to a dict of asset id -> asset.
ASSET_STORE : AssetStore = {}
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
    """
    Look up the asset stored under ``asset_id`` for ``session_id``.

    Returns None when the session has no entry in ASSET_STORE or when the
    session holds no asset with that id.
    """
    if session_id not in ASSET_STORE:
        return None

    session_assets = ASSET_STORE[session_id]
    return session_assets.get(asset_id)
|
||||
|
||||
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[Asset]:
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
asset_id = str(uuid.uuid4())
|
||||
media_type = detect_media_type(file_path)
|
||||
|
||||
@@ -25,58 +19,66 @@ def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str
|
||||
file_format = get_file_format(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 24)
|
||||
asset =\
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'name': file_name,
|
||||
'size': file_size,
|
||||
'path': file_path
|
||||
}
|
||||
|
||||
if media_type == 'audio':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'audio',
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
})
|
||||
if media_type == 'image':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'image',
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
})
|
||||
if media_type == 'video':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'video',
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
})
|
||||
expires_at = created_at + timedelta(hours = 2)
|
||||
|
||||
if session_id not in ASSET_STORE:
|
||||
ASSET_STORE[session_id] = {}
|
||||
|
||||
ASSET_STORE[session_id][asset_id] = asset #type:ignore[assignment]
|
||||
return asset #type:ignore[return-value]
|
||||
if media_type == 'audio':
|
||||
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
})
|
||||
|
||||
if media_type == 'image':
|
||||
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
})
|
||||
|
||||
if media_type == 'video':
|
||||
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': file_name,
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'size': file_size,
|
||||
'path': file_path,
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
})
|
||||
|
||||
return ASSET_STORE[session_id][asset_id]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
    """
    Fetch the asset registered under ``asset_id`` within ``session_id``.

    Returns None when ASSET_STORE has no entry for the session, or when the
    session's assets do not contain ``asset_id``.
    """
    try:
        session_assets = ASSET_STORE[session_id]
    except KeyError:
        # unknown session: nothing has been stored for it
        return None
    return session_assets.get(asset_id)
|
||||
|
||||
|
||||
def clear() -> None:
    """
    Empty the module-level ASSET_STORE in place, dropping every session entry.
    """
    ASSET_STORE.clear()
|
||||
|
||||
|
||||
def detect_media_type(file_path : str) -> Optional[MediaType]:
    """
    Classify ``file_path`` as 'image', 'video' or 'audio'.

    The checks run in that order and stop at the first match; None is
    returned when none of ``is_image``, ``is_video`` and ``is_audio`` accepts
    the path.
    """
    detected_type : Optional[MediaType] = None

    if is_image(file_path):
        detected_type = 'image'
    elif is_video(file_path):
        detected_type = 'video'
    elif is_audio(file_path):
        detected_type = 'audio'
    return detected_type
|
||||
|
||||
+1
-2
@@ -231,9 +231,8 @@ VideoAsset = TypedDict('VideoAsset',
|
||||
'metadata' : VideoMetadata
|
||||
})
|
||||
|
||||
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]]
|
||||
|
||||
BenchmarkMode = Literal['warm', 'cold']
|
||||
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
|
||||
|
||||
+13
-33
@@ -15,48 +15,28 @@ def before_all() -> None:
|
||||
[
|
||||
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3'
|
||||
])
|
||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), get_test_example_file('source.wav') ])
|
||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), '-t', '1.9', '-ar', '48000', '-ac', '2', get_test_example_file('source.wav') ])
|
||||
|
||||
|
||||
def test_detect_audio_sample_rate() -> None:
    """
    Check the sample rate detected for the mp3 example, the wav example and
    the 'invalid.mp3' path, for which None is expected.
    """
    assert detect_audio_sample_rate(get_test_example_file('source.mp3')) == 44100
    # NOTE(review): the before_all hunk of this file shows source.wav being written with '-ar 48000' - confirm 44100 is still the expected rate
    assert detect_audio_sample_rate(get_test_example_file('source.wav')) == 44100
    assert detect_audio_sample_rate(get_test_example_file('invalid.mp3')) is None
|
||||
|
||||
|
||||
def test_detect_audio_channel_total() -> None:
    """
    Check the channel total detected for the mp3 example, the wav example and
    the 'invalid.mp3' path, for which None is expected.
    """
    assert detect_audio_channel_total(get_test_example_file('source.mp3')) == 1
    # NOTE(review): the before_all hunk of this file shows source.wav being written with '-ac 2' - confirm 1 is still the expected channel total
    assert detect_audio_channel_total(get_test_example_file('source.wav')) == 1
    assert detect_audio_channel_total(get_test_example_file('invalid.mp3')) is None
|
||||
|
||||
|
||||
def test_detect_audio_frame_total() -> None:
    """
    Check the frame total detected for the mp3 example, the wav example and
    the 'invalid.mp3' path, for which None is expected.
    """
    assert detect_audio_frame_total(get_test_example_file('source.mp3')) == 167039
    # NOTE(review): the before_all hunk of this file shows source.wav being written with '-t 1.9' and '-ar 48000' - confirm 167039 is still the expected frame total
    assert detect_audio_frame_total(get_test_example_file('source.wav')) == 167039
    assert detect_audio_frame_total(get_test_example_file('invalid.mp3')) is None
|
||||
|
||||
|
||||
def test_detect_audio_format() -> None:
    """
    Check the format detected for the mp3 example ('mp3'), the wav example
    ('pcm_s16le') and the 'invalid.mp3' path, for which None is expected.
    """
    assert detect_audio_format(get_test_example_file('source.mp3')) == 'mp3'
    assert detect_audio_format(get_test_example_file('source.wav')) == 'pcm_s16le'
    assert detect_audio_format(get_test_example_file('invalid.mp3')) is None
|
||||
|
||||
Reference in New Issue
Block a user