refactor api

This commit is contained in:
harisreedhar
2026-01-22 16:05:32 +05:30
parent ed1c9b0b24
commit b7b60c186f
15 changed files with 596 additions and 464 deletions
+50
View File
@@ -0,0 +1,50 @@
from typing import Optional
from facefusion.audio import detect_audio_duration
from facefusion.filesystem import is_audio, is_image, is_video
from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
def extract_audio_metadata(file_path : str) -> AudioMetadata:
metadata : AudioMetadata =\
{
'duration': detect_audio_duration(file_path),
'sample_rate': 0,
'frame_total': 0,
'channels': 0,
'format': ''
}
return metadata
def extract_image_metadata(file_path : str) -> ImageMetadata:
resolution = detect_image_resolution(file_path)
metadata : ImageMetadata =\
{
'resolution': resolution if resolution else (0, 0)
}
return metadata
def extract_video_metadata(file_path : str) -> VideoMetadata:
resolution = detect_video_resolution(file_path)
fps = detect_video_fps(file_path)
metadata : VideoMetadata =\
{
'duration': detect_video_duration(file_path),
'frame_total': count_video_frame_total(file_path),
'fps': fps if fps else 0.0,
'resolution': resolution if resolution else (0, 0)
}
return metadata
def detect_media_type(file_path : str) -> Optional[MediaType]:
if is_audio(file_path):
return 'audio'
if is_image(file_path):
return 'image'
if is_video(file_path):
return 'video'
return None
+92
View File
@@ -0,0 +1,92 @@
import os
import uuid
from datetime import datetime, timedelta
from typing import List, Optional, cast
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
from facefusion.filesystem import get_file_format, get_file_name
from facefusion.types import AssetId, AssetSet, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
ASSET_STORE : AssetStore = {}
def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
asset_id = str(uuid.uuid4())
asset_name = get_file_name(asset_path)
asset_format = get_file_format(asset_path)
asset_size = os.path.getsize(asset_path)
media_type = detect_media_type(asset_path)
created_at = datetime.now()
expires_at = created_at + timedelta(hours = 2)
if session_id not in ASSET_STORE:
ASSET_STORE[session_id] = {}
if media_type == 'audio':
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': asset_name,
'format': cast(AudioFormat, asset_format),
'size': asset_size,
'path': asset_path,
'metadata': extract_audio_metadata(asset_path)
})
if media_type == 'image':
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': asset_name,
'format': cast(ImageFormat, asset_format),
'size': asset_size,
'path': asset_path,
'metadata': extract_image_metadata(asset_path)
})
if media_type == 'video':
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
{
'id': asset_id,
'created_at': created_at,
'expires_at': expires_at,
'type': asset_type,
'media': media_type,
'name': asset_name,
'format': cast(VideoFormat, asset_format),
'size': asset_size,
'path': asset_path,
'metadata': extract_video_metadata(asset_path)
})
return ASSET_STORE[session_id].get(asset_id)
def get_assets(session_id : SessionId) -> Optional[AssetSet]:
return ASSET_STORE.get(session_id)
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
if session_id in ASSET_STORE:
return ASSET_STORE.get(session_id).get(asset_id)
return None
def delete_assets(session_id : SessionId, asset_ids : List[AssetId]) -> None:
if session_id in ASSET_STORE:
for asset_id in asset_ids:
if asset_id in ASSET_STORE.get(session_id):
del ASSET_STORE[session_id][asset_id]
return None
def clear() -> None:
ASSET_STORE.clear()
-186
View File
@@ -1,186 +0,0 @@
import os
import tempfile
from starlette.requests import Request
from starlette.responses import FileResponse, JSONResponse
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
from facefusion import asset_store, filesystem, logger
from facefusion.vision import count_video_frame_total, detect_video_fps, detect_video_resolution
async def upload_asset(request : Request) -> JSONResponse:
asset_type = request.query_params.get('type')
if not asset_type:
return JSONResponse({'message': 'Missing required query parameter: type'}, status_code = HTTP_400_BAD_REQUEST)
if asset_type not in ['source', 'target']:
return JSONResponse({'message': 'Invalid type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST)
form = await request.form()
if asset_type == 'source':
files = form.getlist('file')
if not files:
return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST)
asset_ids = []
for file in files:
filename = file.filename if hasattr(file, 'filename') else 'source.jpg'
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file:
content = await file.read()
temp_file.write(content)
file_path = temp_file.name
if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)):
if os.path.exists(file_path):
os.remove(file_path)
return JSONResponse(
{
'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.'
},
status_code = HTTP_400_BAD_REQUEST
)
asset_id = asset_store.register('source', file_path, filename)
asset_ids.append(asset_id)
logger.debug(f'Uploaded {len(asset_ids)} source(s)', __name__)
return JSONResponse(
{
'message': f'{len(asset_ids)} source(s) uploaded successfully',
'asset_ids': asset_ids
},
status_code = HTTP_201_CREATED
)
if asset_type == 'target':
file = form.get('file')
if not file:
return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST)
if isinstance(file, str):
return JSONResponse({'message': 'Expected file upload, got string. Use /stream/target for URLs'}, status_code = HTTP_400_BAD_REQUEST)
if not hasattr(file, 'filename'):
return JSONResponse({'message': 'Invalid file object'}, status_code = HTTP_400_BAD_REQUEST)
filename = file.filename
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file:
content = await file.read()
temp_file.write(content)
file_path = temp_file.name
if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)):
if os.path.exists(file_path):
os.remove(file_path)
return JSONResponse(
{
'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.'
},
status_code = HTTP_400_BAD_REQUEST
)
metadata = None
if filesystem.is_video(file_path):
frame_total = count_video_frame_total(file_path)
fps = detect_video_fps(file_path)
resolution = detect_video_resolution(file_path)
metadata =\
{
'frame_total': frame_total,
'fps': fps,
'resolution': resolution
}
logger.debug(f'Video metadata - frames: {frame_total}, fps: {fps}, resolution: {resolution}', __name__)
asset_id = asset_store.register('target', file_path, filename, metadata)
logger.debug(f'Target uploaded with asset_id: {asset_id}', __name__)
return JSONResponse(
{
'message': 'Target uploaded successfully',
'asset_id': asset_id
},
status_code = HTTP_201_CREATED
)
async def list_all_assets(request : Request) -> JSONResponse:
asset_type = request.query_params.get('type')
media_type = request.query_params.get('media_type')
format = request.query_params.get('format')
assets = asset_store.list_assets(asset_type)
if media_type:
assets = [a for a in assets if a.get('media_type') == media_type]
if format:
assets = [a for a in assets if a.get('format') == format]
safe_assets = []
for asset in assets:
safe_asset = {k: v for k, v in asset.items() if k != 'path'}
safe_assets.append(safe_asset)
return JSONResponse({'assets': safe_assets, 'count': len(safe_assets)}, status_code = HTTP_200_OK)
async def get_asset_by_id(request : Request) -> JSONResponse | FileResponse:
from facefusion.session_context import get_session_id
asset_id = request.path_params.get('asset_id')
action = request.query_params.get('action')
asset = asset_store.get_asset(asset_id)
if not asset:
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
if asset.get('session_id') != get_session_id():
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
if action == 'download':
file_path = asset.get('path')
if not file_path or not os.path.exists(file_path):
return JSONResponse({'message': 'Asset file not found'}, status_code = HTTP_404_NOT_FOUND)
filename = asset.get('filename', 'download')
return FileResponse(file_path, filename = filename)
safe_asset = {k: v for k, v in asset.items() if k != 'path'}
return JSONResponse(safe_asset, status_code = HTTP_200_OK)
async def delete_asset_by_id(request : Request) -> JSONResponse:
from facefusion.session_context import get_session_id
asset_id = request.path_params.get('asset_id')
asset = asset_store.get_asset(asset_id)
if not asset:
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
if asset.get('session_id') != get_session_id():
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
success = asset_store.delete_asset(asset_id)
if not success:
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
return JSONResponse({'message': 'Asset deleted successfully'}, status_code = HTTP_200_OK)
+8 -8
View File
@@ -3,14 +3,14 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Route, WebSocketRoute
from facefusion.apis.assets import delete_asset_by_id, get_asset_by_id, list_all_assets, upload_asset
from facefusion.apis.choices import get_choices
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
from facefusion.apis.endpoints.ping import websocket_ping
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
from facefusion.apis.endpoints.state import get_state, set_state
from facefusion.apis.metrics import websocket_metrics
from facefusion.apis.ping import websocket_ping
from facefusion.apis.process import webrtc_offer, webrtc_stream_offer, websocket_process
from facefusion.apis.remote import remote
from facefusion.apis.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
from facefusion.apis.state import get_state, set_state
from facefusion.apis.timeline import get_timeline
from facefusion.apis.version import create_version_guard
@@ -26,11 +26,11 @@ def create_api() -> Starlette:
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
Route('/state', get_state, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ version_guard, session_guard ]),
Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
Route('/assets', list_all_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/assets/{asset_id}', get_asset_by_id, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/assets/{asset_id}', delete_asset_by_id, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
Route('/choices', get_choices, methods=['GET'], middleware=[ version_guard, session_guard ]),
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
Route('/choices', get_choices, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/remote', remote, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
Route('/timeline/{count:int}', get_timeline, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
Route('/webrtc/offer', webrtc_offer, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
+147
View File
@@ -0,0 +1,147 @@
import tempfile
from typing import List
from starlette.datastructures import UploadFile
from starlette.requests import Request
from starlette.responses import FileResponse, JSONResponse, Response
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
from facefusion import session_manager
from facefusion.apis import asset_store
from facefusion.apis.asset_helper import detect_media_type
from facefusion.apis.endpoints.session import extract_access_token
from facefusion.filesystem import get_file_extension, remove_file
async def upload_asset(request : Request) -> Response:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
asset_type = request.query_params.get('type')
if session_id and asset_type in [ 'source', 'target' ]:
form = await request.form()
upload_files = form.getlist('file')
asset_paths = await save_asset_files(upload_files) # type: ignore[arg-type]
if asset_paths:
asset_ids : List[str] = []
for asset_path in asset_paths:
asset = asset_store.create_asset(session_id, asset_type, asset_path) # type: ignore[arg-type]
asset_id = asset.get('id')
if asset_id:
asset_ids.append(asset_id)
if asset_ids:
return JSONResponse(
{
'asset_ids': asset_ids
}, status_code = HTTP_201_CREATED)
return Response(status_code = HTTP_400_BAD_REQUEST)
async def save_asset_files(upload_files : List[UploadFile]) -> List[str]:
asset_paths : List[str] = []
for upload_file in upload_files:
upload_file_extension = get_file_extension(upload_file.filename)
with tempfile.NamedTemporaryFile(suffix = upload_file_extension, delete = False) as temp_file:
while upload_chunk := await upload_file.read(1024):
temp_file.write(upload_chunk)
temp_file.flush()
media_type = detect_media_type(temp_file.name)
if media_type:
asset_paths.append(temp_file.name)
else:
remove_file(temp_file.name)
return asset_paths
async def get_assets(request : Request) -> Response:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
if session_id:
asset_set = asset_store.get_assets(session_id)
assets = []
if asset_set:
for asset in asset_set.values():
assets.append(
{
'id': asset.get('id'),
'created_at': asset.get('created_at').isoformat(),
'expires_at': asset.get('expires_at').isoformat(),
'type': asset.get('type'),
'media': asset.get('media'),
'name': asset.get('name'),
'format': asset.get('format'),
'size': asset.get('size'),
'metadata': asset.get('metadata')
})
return JSONResponse(
{
'assets': assets
}, status_code = HTTP_200_OK)
return Response(status_code = HTTP_400_BAD_REQUEST)
async def get_asset(request : Request) -> Response:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
asset_id = request.path_params.get('asset_id')
action = request.query_params.get('action')
if session_id and asset_id:
asset = asset_store.get_asset(session_id, asset_id)
if asset:
if action == 'download':
asset_path = asset.get('path')
asset_name = asset.get('name')
return FileResponse(asset_path, filename = asset_name)
return JSONResponse(
{
'id': asset.get('id'),
'created_at': asset.get('created_at').isoformat(),
'expires_at': asset.get('expires_at').isoformat(),
'type': asset.get('type'),
'media': asset.get('media'),
'name': asset.get('name'),
'format': asset.get('format'),
'size': asset.get('size'),
'metadata': asset.get('metadata')
}, status_code = HTTP_200_OK)
return Response(status_code = HTTP_404_NOT_FOUND)
async def delete_assets(request : Request) -> Response:
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
body = await request.json()
asset_ids = body.get('asset_ids')
if session_id and asset_ids:
asset_set = asset_store.get_assets(session_id)
if asset_set:
for asset_id in asset_ids:
if asset_id in asset_set:
remove_file(asset_set.get(asset_id).get('path'))
asset_store.delete_assets(session_id, asset_ids)
return Response(status_code = HTTP_200_OK)
return Response(status_code = HTTP_404_NOT_FOUND)
@@ -30,7 +30,7 @@ async def create_session(request : Request) -> JSONResponse:
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
'message': translator.get('something_went_wrong', 'facefusion.apis')
}, status_code = HTTP_401_UNAUTHORIZED)
@@ -53,7 +53,7 @@ async def get_session(request : Request) -> JSONResponse:
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
'message': translator.get('something_went_wrong', 'facefusion.apis')
}, status_code = HTTP_401_UNAUTHORIZED)
@@ -73,7 +73,7 @@ async def refresh_session(request : Request) -> JSONResponse:
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
'message': translator.get('something_went_wrong', 'facefusion.apis')
}, status_code = HTTP_401_UNAUTHORIZED)
@@ -88,12 +88,12 @@ async def destroy_session(request : Request) -> JSONResponse:
return JSONResponse(
{
'message': translator.get('ok', __package__)
'message': translator.get('ok', 'facefusion.apis')
}, status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('something_went_wrong', __package__)
'message': translator.get('something_went_wrong', 'facefusion.apis')
}, status_code = HTTP_401_UNAUTHORIZED)
@@ -106,20 +106,19 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
if session_id:
if session_manager.validate_session(session_id):
from facefusion.session_context import set_session_id
set_session_id(session_id)
session_context.set_session_id(session_id)
return await app(scope, receive, send)
response = JSONResponse(
{
'message': translator.get('invalid_access_token', __package__)
'message': translator.get('invalid_access_token', 'facefusion.apis')
}, status_code = HTTP_426_UPGRADE_REQUIRED)
return await response(scope, receive, send)
response = JSONResponse(
{
'message': translator.get('invalid_access_token', __package__)
'message': translator.get('invalid_access_token', 'facefusion.apis')
}, status_code = HTTP_401_UNAUTHORIZED)
return await response(scope, receive, send)
+80
View File
@@ -0,0 +1,80 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND
from facefusion import args_store, session_manager, state_manager, translator
from facefusion.apis import asset_store
from facefusion.apis.endpoints.session import extract_access_token
async def get_state(request : Request) -> JSONResponse:
api_args = args_store.filter_api_args(state_manager.get_state())
return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK)
async def set_state(request : Request) -> JSONResponse:
action = request.query_params.get('action')
asset_type = request.query_params.get('asset_type')
if action == 'select' and asset_type == 'source':
return await select_source(request)
if action == 'select' and asset_type == 'target':
return await select_target(request)
body = await request.json()
api_args = args_store.get_api_args()
for key, value in body.items():
if key in api_args:
state_manager.set_item(key, value)
__api_args__ = args_store.filter_api_args(state_manager.get_state())
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
async def select_source(request : Request) -> JSONResponse:
body = await request.json()
asset_ids = body.get('asset_ids')
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
if isinstance(asset_ids, list) and session_id:
source_paths = []
for asset_id in asset_ids:
asset = asset_store.get_asset(session_id, asset_id)
if asset:
source_paths.append(asset.get('path'))
state_manager.set_item('source_paths', source_paths)
__api_args__ = args_store.filter_api_args(state_manager.get_state())
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('source_asset_not_found', 'facefusion.apis')
}, status_code = HTTP_404_NOT_FOUND)
async def select_target(request : Request) -> JSONResponse:
body = await request.json()
asset_id = body.get('asset_id')
access_token = extract_access_token(request.scope)
session_id = session_manager.find_session_id(access_token)
if isinstance(asset_id, str) and session_id:
asset = asset_store.get_asset(session_id, asset_id)
if asset:
state_manager.set_item('target_path', asset.get('path'))
__api_args__ = args_store.filter_api_args(state_manager.get_state())
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
return JSONResponse(
{
'message': translator.get('target_asset_not_found', 'facefusion.apis')
}, status_code = HTTP_404_NOT_FOUND)
+14
View File
@@ -0,0 +1,14 @@
from facefusion.types import Locales
LOCALES : Locales =\
{
'en':
{
'ok': 'ok',
'something_went_wrong': 'something went wrong',
'invalid_access_token': 'invalid access token',
'invalid_refresh_token': 'invalid refresh token',
'source_asset_not_found': 'source asset not found',
'target_asset_not_found': 'target asset not found'
}
}
+21 -20
View File
@@ -9,8 +9,10 @@ from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from facefusion import asset_store, logger
from facefusion import logger
from facefusion.apis import asset_store
from facefusion.choices import audio_formats
from facefusion.session_context import get_session_id
def resolve_image_urls(url : str) -> List[str]:
@@ -66,12 +68,14 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]:
gdl_job = gallery_job.DownloadJob(url)
gdl_job.run()
session_id = get_session_id()
for root, dirs, files in os.walk(output_dir):
for filename in files:
file_path = os.path.join(root, filename)
asset_id = asset_store.register(asset_type, file_path, filename)
asset_ids.append(asset_id)
logger.info(f'Registered image as asset {asset_id}', __name__)
asset = asset_store.create_asset(session_id, asset_type, file_path)
if asset:
asset_ids.append(asset.get('id'))
logger.info(f'Registered image as asset {asset.get("id")}', __name__)
break
@@ -108,9 +112,11 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]:
for chunk in response.iter_bytes(chunk_size = 8192):
f.write(chunk)
asset_id = asset_store.register(asset_type, file_path, filename)
asset_ids.append(asset_id)
logger.info(f'Downloaded and registered image as asset {asset_id}', __name__)
session_id = get_session_id()
asset = asset_store.create_asset(session_id, asset_type, file_path)
if asset:
asset_ids.append(asset.get('id'))
logger.info(f'Downloaded and registered image as asset {asset.get("id")}', __name__)
return asset_ids
@@ -138,9 +144,11 @@ def download_audio_from_url(url : str, asset_type : str) -> List[str]:
for chunk in response.iter_bytes(chunk_size = 8192):
f.write(chunk)
asset_id = asset_store.register(asset_type, file_path, filename)
asset_ids.append(asset_id)
logger.info(f'Downloaded and registered audio as asset {asset_id}', __name__)
session_id = get_session_id()
asset = asset_store.create_asset(session_id, asset_type, file_path)
if asset:
asset_ids.append(asset.get('id'))
logger.info(f'Downloaded and registered audio as asset {asset.get("id")}', __name__)
return asset_ids
@@ -352,16 +360,9 @@ async def remote(request : Request) -> JSONResponse:
total_frames = int(duration * fps)
logger.info(f'Calculated total frames: {total_frames} ({duration}s * {fps} fps)', __name__)
filename = os.path.basename(downloaded_file)
metadata =\
{
'frame_total': total_frames,
'fps': fps,
'resolution': (width, height) if width and height else None,
'duration': duration
}
asset_id = asset_store.register(asset_type, downloaded_file, filename, metadata)
session_id = get_session_id()
asset = asset_store.create_asset(session_id, asset_type, downloaded_file)
asset_id = asset.get('id') if asset else None
logger.info(f'Video downloaded and registered as asset {asset_id}', __name__)
response_data =\
-76
View File
@@ -1,76 +0,0 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
from facefusion import args_store, asset_store, logger, state_manager
async def get_state(request : Request) -> JSONResponse:
api_args = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK)
async def set_state(request : Request) -> JSONResponse:
body = await request.json()
action = request.query_params.get('action')
if action == 'select':
asset_type = request.query_params.get('asset_type')
if not asset_type:
return JSONResponse({'message': 'Missing required query parameter: asset_type'}, status_code = HTTP_400_BAD_REQUEST)
if asset_type not in ['source', 'target']:
return JSONResponse({'message': 'Invalid asset_type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST)
if asset_type == 'source':
asset_ids = body.get('asset_ids', [])
if not isinstance(asset_ids, list):
return JSONResponse({'message': 'asset_ids must be an array'}, status_code = HTTP_400_BAD_REQUEST)
if not asset_ids:
return JSONResponse({'message': 'asset_ids cannot be empty'}, status_code = HTTP_400_BAD_REQUEST)
source_paths = []
for asset_id in asset_ids:
asset = asset_store.get_asset(asset_id)
if not asset:
return JSONResponse({'message': f'Source asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND)
source_paths.append(asset['path'])
state_manager.set_item('source_paths', source_paths)
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
if asset_type == 'target':
asset_id = body.get('asset_id')
if not asset_id:
return JSONResponse({'message': 'Missing required field: asset_id'}, status_code = HTTP_400_BAD_REQUEST)
if not isinstance(asset_id, str):
return JSONResponse({'message': 'asset_id must be a string'}, status_code = HTTP_400_BAD_REQUEST)
asset = asset_store.get_asset(asset_id)
if not asset:
return JSONResponse({'message': f'Target asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND)
state_manager.set_item('target_path', asset['path'])
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
api_args = args_store.get_api_args()
logger.info(f'[State] Normal update - body keys: {list(body.keys())}', __name__)
for key, value in body.items():
if key in api_args:
state_manager.set_item(key, value)
logger.debug(f'[State] Set {key} = {value}', __name__)
else:
logger.warn(f'[State] Skipped {key} (not in api_args)', __name__)
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
+3 -6
View File
@@ -9,7 +9,7 @@ from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
from facefusion import logger
from facefusion.asset_store import get_asset
from facefusion.apis import asset_store
from facefusion.filesystem import is_video
from facefusion.video_manager import get_video_capture
from facefusion.vision import fit_contain_frame
@@ -81,14 +81,11 @@ async def get_timeline(request: Request) -> JSONResponse:
if asset_id and not target_path:
from facefusion.session_context import get_session_id
asset = get_asset(asset_id)
session_id = get_session_id()
asset = asset_store.get_asset(session_id, asset_id)
if not asset:
return JSONResponse({'message': f'Asset not found: {asset_id}'}, status_code=HTTP_400_BAD_REQUEST)
# Verify asset belongs to current session (security)
if asset.get('session_id') != get_session_id():
return JSONResponse({'message': 'Asset not found'}, status_code=HTTP_400_BAD_REQUEST)
target_path = asset.get('path')
if not target_path:
return JSONResponse({'message': 'Asset has no path'}, status_code=HTTP_400_BAD_REQUEST)
+64
View File
@@ -162,6 +162,70 @@ AudioTypeSet : TypeAlias = Dict[AudioFormat, str]
ImageTypeSet : TypeAlias = Dict[ImageFormat, str]
VideoTypeSet : TypeAlias = Dict[VideoFormat, str]
AssetId : TypeAlias = str
AssetType = Literal['source', 'target']
AudioMetadata = TypedDict('AudioMetadata',
{
'duration' : Duration,
'sample_rate': int,
'frame_total': int,
'channels': int,
'format': str
})
ImageMetadata = TypedDict('ImageMetadata',
{
'resolution' : Resolution
})
VideoMetadata = TypedDict('VideoMetadata',
{
'duration' : Duration,
'frame_total' : int,
'fps' : Fps,
'resolution' : Resolution
})
AudioAsset = TypedDict('AudioAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['audio'],
'name' : str,
'format' : AudioFormat,
'size' : int,
'path' : str,
'metadata' : AudioMetadata
})
ImageAsset = TypedDict('ImageAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['image'],
'name' : str,
'format' : ImageFormat,
'size' : int,
'path' : str,
'metadata' : ImageMetadata
})
VideoAsset = TypedDict('VideoAsset',
{
'id' : AssetId,
'created_at' : datetime,
'expires_at' : datetime,
'type' : AssetType,
'media' : Literal['video'],
'name' : str,
'format' : VideoFormat,
'size' : int,
'path' : str,
'metadata' : VideoMetadata
})
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
AssetSet : TypeAlias = Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]
AssetStore : TypeAlias = Dict[SessionId, AssetSet]
AudioEncoder = Literal['flac', 'aac', 'libmp3lame', 'libopus', 'libvorbis', 'pcm_s16le', 'pcm_s32le']
VideoEncoder = Literal['libx264', 'libx264rgb', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf', 'h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox', 'rawvideo']
EncoderSet = TypedDict('EncoderSet',
+109 -159
View File
@@ -2,24 +2,24 @@ import os
import tempfile
from typing import Iterator
import cv2
import numpy
import pytest
from facefusion import asset_store, session_manager, state_manager
from facefusion.session_context import clear_session_id, set_session_id
from facefusion.apis import asset_store
@pytest.fixture(scope = 'function', autouse = True)
def before_each() -> None:
session_manager.SESSIONS.clear()
state_manager.clear_item('asset_registry')
clear_session_id()
asset_store.clear()
@pytest.fixture(scope = 'function')
def temp_file() -> Iterator[str]:
fd, path = tempfile.mkstemp(suffix = '.jpg')
os.write(fd, b'test file content')
os.close(fd)
image = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path, image)
yield path
if os.path.exists(path):
os.remove(path)
@@ -27,221 +27,171 @@ def temp_file() -> Iterator[str]:
@pytest.fixture(scope = 'function')
def session_id() -> str:
test_session_id = 'test-session-123'
set_session_id(test_session_id)
return test_session_id
return 'test-session-123'
def test_register_source_asset(temp_file : str, session_id : str) -> None:
asset_id = asset_store.register('source', temp_file, 'test.jpg')
def test_create_source_asset(temp_file : str, session_id : str) -> None:
asset = asset_store.create_asset(session_id, 'source', temp_file)
assert isinstance(asset_id, str)
assert len(asset_id) == 36
asset = asset_store.get_asset(asset_id)
assert asset is not None
assert asset.get('id') == asset_id
assert asset.get('session_id') == session_id
assert isinstance(asset.get('id'), str)
assert len(asset.get('id')) == 36
assert asset.get('type') == 'source'
assert asset.get('path') == temp_file
assert asset.get('filename') == 'test.jpg'
assert asset.get('size') > 0
assert asset.get('created_at')
def test_register_target_asset(temp_file : str, session_id : str) -> None:
asset_id = asset_store.register('target', temp_file, 'video.mp4')
def test_create_target_asset(temp_file : str, session_id : str) -> None:
asset = asset_store.create_asset(session_id, 'target', temp_file)
asset = asset_store.get_asset(asset_id)
assert asset is not None
assert asset.get('type') == 'target'
assert asset.get('filename') == 'video.mp4'
def test_register_output_asset(temp_file : str, session_id : str) -> None:
metadata = {'fps': 30, 'resolution': [1920, 1080]}
asset_id = asset_store.register('output', temp_file, 'output.mp4', metadata)
def test_get_asset(temp_file : str, session_id : str) -> None:
created_asset = asset_store.create_asset(session_id, 'source', temp_file)
asset_id = created_asset.get('id')
asset = asset_store.get_asset(asset_id)
assert asset.get('type') == 'output'
assert asset.get('metadata') == metadata
def test_register_invalid_type(temp_file : str, session_id : str) -> None:
with pytest.raises(ValueError) as exc:
asset_store.register('invalid_type', temp_file, 'test.jpg')
assert "Invalid asset_type" in str(exc.value)
def test_register_without_session() -> None:
fd, path = tempfile.mkstemp()
os.close(fd)
try:
with pytest.raises(ValueError) as exc:
asset_store.register('source', path, 'test.jpg')
assert "No active session" in str(exc.value)
finally:
os.remove(path)
def test_register_without_filename(temp_file : str, session_id : str) -> None:
asset_id = asset_store.register('source', temp_file)
asset = asset_store.get_asset(asset_id)
assert asset.get('filename') == os.path.basename(temp_file)
asset = asset_store.get_asset(session_id, asset_id)
assert asset is not None
assert asset.get('id') == asset_id
assert asset.get('type') == 'source'
def test_get_asset_not_found(session_id : str) -> None:
asset = asset_store.get_asset('non-existent-id')
asset = asset_store.get_asset(session_id, 'non-existent-id')
assert asset is None
def test_list_assets_empty(session_id : str) -> None:
assets = asset_store.list_assets()
assert assets == []
def test_get_asset_wrong_session(temp_file : str, session_id : str) -> None:
created_asset = asset_store.create_asset(session_id, 'source', temp_file)
asset_id = created_asset.get('id')
asset = asset_store.get_asset('different-session', asset_id)
assert asset is None
def test_list_assets_with_multiple(temp_file : str, session_id : str) -> None:
def test_get_assets_empty(session_id : str) -> None:
assets = asset_store.get_assets(session_id)
assert assets is None
def test_get_assets_with_multiple(temp_file : str, session_id : str) -> None:
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
os.write(fd1, b'content 1')
os.close(fd1)
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path1, image1)
fd2, path2 = tempfile.mkstemp(suffix = '.mp4')
os.write(fd2, b'content 2')
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
os.close(fd2)
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path2, image2)
try:
asset_store.register('source', path1, 'source1.jpg')
asset_store.register('source', path2, 'source2.jpg')
asset_store.register('target', temp_file, 'target.mp4')
asset_store.create_asset(session_id, 'source', path1)
asset_store.create_asset(session_id, 'source', path2)
asset_store.create_asset(session_id, 'target', temp_file)
assets = asset_store.list_assets()
assets = asset_store.get_assets(session_id)
assert assets is not None
assert len(assets) == 3
finally:
os.remove(path1)
os.remove(path2)
if os.path.exists(path1):
os.remove(path1)
if os.path.exists(path2):
os.remove(path2)
def test_list_assets_filter_by_type(temp_file : str, session_id : str) -> None:
fd, path = tempfile.mkstemp(suffix = '.jpg')
os.write(fd, b'content')
os.close(fd)
try:
asset_store.register('source', path, 'source.jpg')
asset_store.register('target', temp_file, 'target.mp4')
source_assets = asset_store.list_assets('source')
assert len(source_assets) == 1
assert source_assets[0].get('type') == 'source'
target_assets = asset_store.list_assets('target')
assert len(target_assets) == 1
assert target_assets[0].get('type') == 'target'
output_assets = asset_store.list_assets('output')
assert len(output_assets) == 0
finally:
os.remove(path)
def test_list_assets_invalid_type(session_id : str) -> None:
with pytest.raises(ValueError) as exc:
asset_store.list_assets('invalid_type')
assert "Invalid asset_type" in str(exc.value)
def test_list_assets_session_scoped(temp_file : str) -> None:
def test_get_assets_session_scoped(temp_file : str) -> None:
session1_id = 'session-1'
set_session_id(session1_id)
asset1_id = asset_store.register('source', temp_file, 'file1.jpg')
asset1 = asset_store.create_asset(session1_id, 'source', temp_file)
asset1_id = asset1.get('id')
session2_id = 'session-2'
set_session_id(session2_id)
fd, path2 = tempfile.mkstemp(suffix = '.jpg')
os.write(fd, b'content 2')
os.close(fd)
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path2, image2)
try:
asset2_id = asset_store.register('source', path2, 'file2.jpg')
asset2 = asset_store.create_asset(session2_id, 'source', path2)
asset2_id = asset2.get('id')
assets_session2 = asset_store.list_assets()
assets_session2 = asset_store.get_assets(session2_id)
assert assets_session2 is not None
assert len(assets_session2) == 1
assert assets_session2[0].get('id') == asset2_id
assert asset2_id in assets_session2
set_session_id(session1_id)
assets_session1 = asset_store.list_assets()
assets_session1 = asset_store.get_assets(session1_id)
assert assets_session1 is not None
assert len(assets_session1) == 1
assert assets_session1[0].get('id') == asset1_id
assert asset1_id in assets_session1
finally:
os.remove(path2)
if os.path.exists(path2):
os.remove(path2)
def test_delete_asset(temp_file : str, session_id : str) -> None:
asset_id = asset_store.register('source', temp_file, 'test.jpg')
assert os.path.exists(temp_file)
success = asset_store.delete_asset(asset_id)
assert success is True
assert not os.path.exists(temp_file)
asset = asset_store.get_asset(asset_id)
assert asset is None
def test_delete_asset_not_found(session_id : str) -> None:
success = asset_store.delete_asset('non-existent-id')
assert success is False
def test_cleanup_session_assets(session_id : str) -> None:
def test_delete_assets(session_id : str) -> None:
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
os.write(fd1, b'content 1')
os.close(fd1)
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path1, image1)
fd2, path2 = tempfile.mkstemp(suffix = '.mp4')
os.write(fd2, b'content 2')
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
os.close(fd2)
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path2, image2)
asset_id1 = asset_store.register('source', path1, 'source.jpg')
asset_id2 = asset_store.register('target', path2, 'target.mp4')
try:
asset1 = asset_store.create_asset(session_id, 'source', path1)
asset2 = asset_store.create_asset(session_id, 'source', path2)
asset1_id = asset1.get('id')
asset2_id = asset2.get('id')
assert os.path.exists(path1)
assert os.path.exists(path2)
asset_store.delete_assets(session_id, [asset1_id])
asset_store.cleanup_session_assets(session_id)
assert not os.path.exists(path1)
assert not os.path.exists(path2)
assert asset_store.get_asset(asset_id1) is None
assert asset_store.get_asset(asset_id2) is None
assets = asset_store.get_assets(session_id)
assert assets is not None
assert len(assets) == 1
assert asset2_id in assets
assert asset1_id not in assets
finally:
if os.path.exists(path1):
os.remove(path1)
if os.path.exists(path2):
os.remove(path2)
def test_cleanup_session_assets_only_affects_target_session(temp_file : str) -> None:
session1_id = 'session-1'
set_session_id(session1_id)
def test_delete_assets_not_found(session_id : str) -> None:
asset_store.delete_assets(session_id, ['non-existent-id'])
fd, path1 = tempfile.mkstemp(suffix = '.jpg')
os.write(fd, b'content 1')
os.close(fd)
asset1_id = asset_store.register('source', path1, 'file1.jpg')
def test_clear() -> None:
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
os.close(fd1)
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path1, image1)
session2_id = 'session-2'
set_session_id(session2_id)
asset2_id = asset_store.register('source', temp_file, 'file2.jpg')
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
os.close(fd2)
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
cv2.imwrite(path2, image2)
asset_store.cleanup_session_assets(session1_id)
try:
session1_id = 'session-1'
session2_id = 'session-2'
assert not os.path.exists(path1)
assert os.path.exists(temp_file)
asset_store.create_asset(session1_id, 'source', path1)
asset_store.create_asset(session2_id, 'source', path2)
set_session_id(session1_id)
assert asset_store.get_asset(asset1_id) is None
asset_store.clear()
set_session_id(session2_id)
assert asset_store.get_asset(asset2_id) is not None
assert asset_store.get_assets(session1_id) is None
assert asset_store.get_assets(session2_id) is None
finally:
if os.path.exists(path1):
os.remove(path1)
if os.path.exists(path2):
os.remove(path2)