mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 17:36:16 +02:00
asset store update
This commit is contained in:
@@ -3,6 +3,10 @@ from typing import Optional
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import Scope
|
||||
|
||||
from facefusion.audio import detect_audio_duration
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, VideoMetadata
|
||||
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
|
||||
protocol_header = Headers(scope = scope).get('Sec-WebSocket-Protocol')
|
||||
@@ -13,3 +17,29 @@ def get_sec_websocket_protocol(scope : Scope) -> Optional[str]:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path : str) -> AudioMetadata:
|
||||
metadata : AudioMetadata =\
|
||||
{
|
||||
'duration' : detect_audio_duration(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def extract_image_metadata(file_path : str) -> ImageMetadata:
|
||||
metadata : ImageMetadata =\
|
||||
{
|
||||
'resolution' : detect_image_resolution(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def extract_video_metadata(file_path : str) -> VideoMetadata:
|
||||
metadata : VideoMetadata =\
|
||||
{
|
||||
'duration' : detect_video_duration(file_path),
|
||||
'frame_total' : count_video_frame_total(file_path),
|
||||
'fps' : detect_video_fps(file_path),
|
||||
'resolution' : detect_video_resolution(file_path)
|
||||
}
|
||||
return metadata
|
||||
|
||||
@@ -1,23 +1,82 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
|
||||
ASSET_STORE : Dict[str, Dict[str, str]] = {}
|
||||
from facefusion.apis.api_helper import extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name, is_audio, is_image, is_video
|
||||
from facefusion.types import Asset, AssetId, AssetStore, AssetType, AudioFormat, ImageFormat, MediaType, SessionId, VideoFormat
|
||||
|
||||
ASSET_STORE : AssetStore = {}
|
||||
|
||||
|
||||
def get_asset(asset_id : str) -> Optional[Dict[str, str]]:
|
||||
return ASSET_STORE.get(asset_id)
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[Asset]:
|
||||
if session_id in ASSET_STORE:
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
return None
|
||||
|
||||
|
||||
def register_asset(path : str) -> str:
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, file_path : str) -> Optional[Asset]:
|
||||
asset_id = str(uuid.uuid4())
|
||||
media_type = detect_media_type(file_path)
|
||||
|
||||
ASSET_STORE[asset_id] =\
|
||||
{
|
||||
'id': asset_id,
|
||||
'path': path
|
||||
}
|
||||
return asset_id
|
||||
if media_type:
|
||||
file_name = get_file_name(file_path)
|
||||
file_format = get_file_format(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 24)
|
||||
asset =\
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'name': file_name,
|
||||
'size': file_size,
|
||||
'path': file_path
|
||||
}
|
||||
|
||||
if media_type == 'audio':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'audio',
|
||||
'format': cast(AudioFormat, file_format),
|
||||
'metadata': extract_audio_metadata(file_path)
|
||||
})
|
||||
if media_type == 'image':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'image',
|
||||
'format': cast(ImageFormat, file_format),
|
||||
'metadata': extract_image_metadata(file_path)
|
||||
})
|
||||
if media_type == 'video':
|
||||
asset.update(
|
||||
{
|
||||
'media': 'video',
|
||||
'format': cast(VideoFormat, file_format),
|
||||
'metadata': extract_video_metadata(file_path)
|
||||
})
|
||||
|
||||
if session_id not in ASSET_STORE:
|
||||
ASSET_STORE[session_id] = {}
|
||||
|
||||
ASSET_STORE[session_id][asset_id] = asset #type:ignore[assignment]
|
||||
return asset #type:ignore[return-value]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
ASSET_STORE.clear()
|
||||
|
||||
|
||||
def detect_media_type(file_path : str) -> Optional[MediaType]:
|
||||
if is_image(file_path):
|
||||
return 'image'
|
||||
if is_video(file_path):
|
||||
return 'video'
|
||||
if is_audio(file_path):
|
||||
return 'audio'
|
||||
return None
|
||||
|
||||
@@ -2,8 +2,9 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import args_store, state_manager, translator
|
||||
from facefusion import args_store, session_manager, state_manager, translator
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
|
||||
|
||||
async def get_state(request : Request) -> JSONResponse:
|
||||
@@ -35,12 +36,14 @@ async def set_state(request : Request) -> JSONResponse:
|
||||
async def select_source(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
asset_ids = body.get('asset_ids')
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if isinstance(asset_ids, list):
|
||||
if isinstance(asset_ids, list) and session_id:
|
||||
source_paths = []
|
||||
|
||||
for asset_id in asset_ids:
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
source_paths.append(asset.get('path'))
|
||||
@@ -59,9 +62,11 @@ async def select_source(request : Request) -> JSONResponse:
|
||||
async def select_target(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
asset_id = body.get('asset_id')
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if isinstance(asset_id, str):
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
if isinstance(asset_id, str) and session_id:
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
state_manager.set_item('target_path', asset.get('path'))
|
||||
|
||||
@@ -169,6 +169,68 @@ EncoderSet = TypedDict('EncoderSet',
|
||||
})
|
||||
VideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow']
|
||||
|
||||
AssetId : TypeAlias = str
|
||||
AssetType = Literal['source', 'target']
|
||||
MediaType = Literal['image', 'video', 'audio']
|
||||
AudioMetadata = TypedDict('AudioMetadata',
|
||||
{
|
||||
'duration' : Duration
|
||||
})
|
||||
ImageMetadata = TypedDict('ImageMetadata',
|
||||
{
|
||||
'resolution' : Resolution
|
||||
})
|
||||
VideoMetadata = TypedDict('VideoMetadata',
|
||||
{
|
||||
'duration' : Duration,
|
||||
'frame_total' : int,
|
||||
'fps' : Fps,
|
||||
'resolution' : Resolution
|
||||
})
|
||||
AudioAsset = TypedDict('AudioAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['audio'],
|
||||
'name' : str,
|
||||
'format' : AudioFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : AudioMetadata
|
||||
})
|
||||
ImageAsset = TypedDict('ImageAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['image'],
|
||||
'name' : str,
|
||||
'format' : ImageFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : ImageMetadata
|
||||
})
|
||||
VideoAsset = TypedDict('VideoAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['video'],
|
||||
'name' : str,
|
||||
'format' : VideoFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : VideoMetadata
|
||||
})
|
||||
|
||||
Asset : TypeAlias = AudioAsset | ImageAsset | VideoAsset
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, Asset]]
|
||||
|
||||
BenchmarkMode = Literal['warm', 'cold']
|
||||
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
|
||||
BenchmarkSet : TypeAlias = Dict[BenchmarkResolution, str]
|
||||
|
||||
+44
-22
@@ -1,3 +1,4 @@
|
||||
import subprocess
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
@@ -6,6 +7,18 @@ from starlette.testclient import TestClient
|
||||
from facefusion import args_store, metadata, session_manager, state_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.core import create_api
|
||||
from facefusion.download import conditional_download
|
||||
from .helper import get_test_example_file, get_test_examples_directory
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
|
||||
def before_all() -> None:
|
||||
conditional_download(get_test_examples_directory(),
|
||||
[
|
||||
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg',
|
||||
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
|
||||
])
|
||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ])
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module')
|
||||
@@ -85,10 +98,23 @@ def test_set_state(test_client : TestClient) -> None:
|
||||
|
||||
|
||||
def test_select_source_assets(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
|
||||
create_session_body = create_session_response.json()
|
||||
access_token = create_session_body.get('access_token')
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
source_paths =\
|
||||
[
|
||||
get_test_example_file('source.jpg'),
|
||||
get_test_example_file('source.jpg')
|
||||
]
|
||||
asset_ids =\
|
||||
[
|
||||
asset_store.register_asset('/path/to/source1.jpg'),
|
||||
asset_store.register_asset('/path/to/source2.jpg')
|
||||
asset_store.create_asset(session_id, 'source', source_paths[0]).get('id'),
|
||||
asset_store.create_asset(session_id, 'source', source_paths[1]).get('id')
|
||||
]
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=source', json =
|
||||
@@ -98,18 +124,12 @@ def test_select_source_assets(test_client : TestClient) -> None:
|
||||
|
||||
assert select_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=source', json =
|
||||
{
|
||||
'asset_ids': 'invalid'
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
assert select_response.status_code == 404
|
||||
@@ -119,36 +139,38 @@ def test_select_source_assets(test_client : TestClient) -> None:
|
||||
'asset_ids': asset_ids
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
select_body = select_response.json()
|
||||
|
||||
assert select_body.get('source_paths') == [ '/path/to/source1.jpg', '/path/to/source2.jpg' ]
|
||||
assert select_body.get('source_paths') == source_paths
|
||||
assert select_response.status_code == 200
|
||||
|
||||
|
||||
def test_select_target_assets(test_client : TestClient) -> None:
|
||||
asset_id = asset_store.register_asset('/path/to/target.jpg')
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
access_token = create_session_body.get('access_token')
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
target_path = get_test_example_file('target-240p.jpg')
|
||||
asset_id = asset_store.create_asset(session_id, 'target', target_path).get('id')
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=target', json =
|
||||
select_response = test_client.put('/state?action=select&type=target', json=
|
||||
{
|
||||
'asset_id': asset_id
|
||||
})
|
||||
|
||||
assert select_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
select_response = test_client.put('/state?action=select&type=target', json =
|
||||
{
|
||||
'asset_id': 'invalid'
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
|
||||
assert select_response.status_code == 404
|
||||
@@ -158,9 +180,9 @@ def test_select_target_assets(test_client : TestClient) -> None:
|
||||
'asset_id': asset_id
|
||||
}, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
select_body = select_response.json()
|
||||
|
||||
assert select_body.get('target_path') == '/path/to/target.jpg'
|
||||
assert select_body.get('target_path') == target_path
|
||||
assert select_response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user