mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 04:55:57 +02:00
refactor api
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.audio import detect_audio_duration
|
||||
from facefusion.filesystem import is_audio, is_image, is_video
|
||||
from facefusion.types import AudioMetadata, ImageMetadata, MediaType, VideoMetadata
|
||||
from facefusion.vision import count_video_frame_total, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
def extract_audio_metadata(file_path : str) -> AudioMetadata:
|
||||
metadata : AudioMetadata =\
|
||||
{
|
||||
'duration': detect_audio_duration(file_path),
|
||||
'sample_rate': 0,
|
||||
'frame_total': 0,
|
||||
'channels': 0,
|
||||
'format': ''
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def extract_image_metadata(file_path : str) -> ImageMetadata:
|
||||
resolution = detect_image_resolution(file_path)
|
||||
metadata : ImageMetadata =\
|
||||
{
|
||||
'resolution': resolution if resolution else (0, 0)
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def extract_video_metadata(file_path : str) -> VideoMetadata:
|
||||
resolution = detect_video_resolution(file_path)
|
||||
fps = detect_video_fps(file_path)
|
||||
metadata : VideoMetadata =\
|
||||
{
|
||||
'duration': detect_video_duration(file_path),
|
||||
'frame_total': count_video_frame_total(file_path),
|
||||
'fps': fps if fps else 0.0,
|
||||
'resolution': resolution if resolution else (0, 0)
|
||||
}
|
||||
return metadata
|
||||
|
||||
|
||||
def detect_media_type(file_path : str) -> Optional[MediaType]:
|
||||
if is_audio(file_path):
|
||||
return 'audio'
|
||||
if is_image(file_path):
|
||||
return 'image'
|
||||
if is_video(file_path):
|
||||
return 'video'
|
||||
return None
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name
|
||||
from facefusion.types import AssetId, AssetSet, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
|
||||
ASSET_STORE : AssetStore = {}
|
||||
|
||||
|
||||
def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : str) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
asset_id = str(uuid.uuid4())
|
||||
asset_name = get_file_name(asset_path)
|
||||
asset_format = get_file_format(asset_path)
|
||||
asset_size = os.path.getsize(asset_path)
|
||||
media_type = detect_media_type(asset_path)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours = 2)
|
||||
|
||||
if session_id not in ASSET_STORE:
|
||||
ASSET_STORE[session_id] = {}
|
||||
|
||||
if media_type == 'audio':
|
||||
ASSET_STORE[session_id][asset_id] = cast(AudioAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': asset_name,
|
||||
'format': cast(AudioFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_audio_metadata(asset_path)
|
||||
})
|
||||
|
||||
if media_type == 'image':
|
||||
ASSET_STORE[session_id][asset_id] = cast(ImageAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': asset_name,
|
||||
'format': cast(ImageFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_image_metadata(asset_path)
|
||||
})
|
||||
|
||||
if media_type == 'video':
|
||||
ASSET_STORE[session_id][asset_id] = cast(VideoAsset,
|
||||
{
|
||||
'id': asset_id,
|
||||
'created_at': created_at,
|
||||
'expires_at': expires_at,
|
||||
'type': asset_type,
|
||||
'media': media_type,
|
||||
'name': asset_name,
|
||||
'format': cast(VideoFormat, asset_format),
|
||||
'size': asset_size,
|
||||
'path': asset_path,
|
||||
'metadata': extract_video_metadata(asset_path)
|
||||
})
|
||||
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
|
||||
|
||||
def get_assets(session_id : SessionId) -> Optional[AssetSet]:
|
||||
return ASSET_STORE.get(session_id)
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
if session_id in ASSET_STORE:
|
||||
return ASSET_STORE.get(session_id).get(asset_id)
|
||||
return None
|
||||
|
||||
|
||||
def delete_assets(session_id : SessionId, asset_ids : List[AssetId]) -> None:
|
||||
if session_id in ASSET_STORE:
|
||||
for asset_id in asset_ids:
|
||||
if asset_id in ASSET_STORE.get(session_id):
|
||||
del ASSET_STORE[session_id][asset_id]
|
||||
return None
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
ASSET_STORE.clear()
|
||||
@@ -1,186 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse, JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import asset_store, filesystem, logger
|
||||
from facefusion.vision import count_video_frame_total, detect_video_fps, detect_video_resolution
|
||||
|
||||
|
||||
async def upload_asset(request : Request) -> JSONResponse:
|
||||
asset_type = request.query_params.get('type')
|
||||
|
||||
if not asset_type:
|
||||
return JSONResponse({'message': 'Missing required query parameter: type'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if asset_type not in ['source', 'target']:
|
||||
return JSONResponse({'message': 'Invalid type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
form = await request.form()
|
||||
|
||||
if asset_type == 'source':
|
||||
files = form.getlist('file')
|
||||
|
||||
if not files:
|
||||
return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
asset_ids = []
|
||||
|
||||
for file in files:
|
||||
filename = file.filename if hasattr(file, 'filename') else 'source.jpg'
|
||||
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file:
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
file_path = temp_file.name
|
||||
|
||||
if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.'
|
||||
},
|
||||
status_code = HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
asset_id = asset_store.register('source', file_path, filename)
|
||||
asset_ids.append(asset_id)
|
||||
|
||||
logger.debug(f'Uploaded {len(asset_ids)} source(s)', __name__)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': f'{len(asset_ids)} source(s) uploaded successfully',
|
||||
'asset_ids': asset_ids
|
||||
},
|
||||
status_code = HTTP_201_CREATED
|
||||
)
|
||||
|
||||
if asset_type == 'target':
|
||||
file = form.get('file')
|
||||
|
||||
if not file:
|
||||
return JSONResponse({'message': 'No file provided'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if isinstance(file, str):
|
||||
return JSONResponse({'message': 'Expected file upload, got string. Use /stream/target for URLs'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not hasattr(file, 'filename'):
|
||||
return JSONResponse({'message': 'Invalid file object'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
filename = file.filename
|
||||
file_extension = os.path.splitext(filename)[1] if filename else '.jpg'
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file:
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
file_path = temp_file.name
|
||||
|
||||
if not (filesystem.is_image(file_path) or filesystem.is_video(file_path) or filesystem.is_audio(file_path)):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': 'Unsupported file format. Allowed formats - Images: bmp, jpeg, png, tiff, webp. Videos: avi, m4v, mkv, mov, mp4, mpeg, mxf, webm, wmv.'
|
||||
},
|
||||
status_code = HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
metadata = None
|
||||
|
||||
if filesystem.is_video(file_path):
|
||||
frame_total = count_video_frame_total(file_path)
|
||||
fps = detect_video_fps(file_path)
|
||||
resolution = detect_video_resolution(file_path)
|
||||
metadata =\
|
||||
{
|
||||
'frame_total': frame_total,
|
||||
'fps': fps,
|
||||
'resolution': resolution
|
||||
}
|
||||
logger.debug(f'Video metadata - frames: {frame_total}, fps: {fps}, resolution: {resolution}', __name__)
|
||||
|
||||
asset_id = asset_store.register('target', file_path, filename, metadata)
|
||||
|
||||
logger.debug(f'Target uploaded with asset_id: {asset_id}', __name__)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': 'Target uploaded successfully',
|
||||
'asset_id': asset_id
|
||||
},
|
||||
status_code = HTTP_201_CREATED
|
||||
)
|
||||
|
||||
|
||||
async def list_all_assets(request : Request) -> JSONResponse:
|
||||
asset_type = request.query_params.get('type')
|
||||
media_type = request.query_params.get('media_type')
|
||||
format = request.query_params.get('format')
|
||||
|
||||
assets = asset_store.list_assets(asset_type)
|
||||
|
||||
if media_type:
|
||||
assets = [a for a in assets if a.get('media_type') == media_type]
|
||||
|
||||
if format:
|
||||
assets = [a for a in assets if a.get('format') == format]
|
||||
|
||||
safe_assets = []
|
||||
for asset in assets:
|
||||
safe_asset = {k: v for k, v in asset.items() if k != 'path'}
|
||||
safe_assets.append(safe_asset)
|
||||
|
||||
return JSONResponse({'assets': safe_assets, 'count': len(safe_assets)}, status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def get_asset_by_id(request : Request) -> JSONResponse | FileResponse:
|
||||
from facefusion.session_context import get_session_id
|
||||
|
||||
asset_id = request.path_params.get('asset_id')
|
||||
action = request.query_params.get('action')
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
|
||||
if not asset:
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
if asset.get('session_id') != get_session_id():
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
if action == 'download':
|
||||
file_path = asset.get('path')
|
||||
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
return JSONResponse({'message': 'Asset file not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
filename = asset.get('filename', 'download')
|
||||
|
||||
return FileResponse(file_path, filename = filename)
|
||||
|
||||
safe_asset = {k: v for k, v in asset.items() if k != 'path'}
|
||||
|
||||
return JSONResponse(safe_asset, status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def delete_asset_by_id(request : Request) -> JSONResponse:
|
||||
from facefusion.session_context import get_session_id
|
||||
|
||||
asset_id = request.path_params.get('asset_id')
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
|
||||
if not asset:
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
if asset.get('session_id') != get_session_id():
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
success = asset_store.delete_asset(asset_id)
|
||||
|
||||
if not success:
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
return JSONResponse({'message': 'Asset deleted successfully'}, status_code = HTTP_200_OK)
|
||||
@@ -3,14 +3,14 @@ from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Route, WebSocketRoute
|
||||
|
||||
from facefusion.apis.assets import delete_asset_by_id, get_asset_by_id, list_all_assets, upload_asset
|
||||
from facefusion.apis.choices import get_choices
|
||||
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
from facefusion.apis.metrics import websocket_metrics
|
||||
from facefusion.apis.ping import websocket_ping
|
||||
from facefusion.apis.process import webrtc_offer, webrtc_stream_offer, websocket_process
|
||||
from facefusion.apis.remote import remote
|
||||
from facefusion.apis.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.state import get_state, set_state
|
||||
from facefusion.apis.timeline import get_timeline
|
||||
from facefusion.apis.version import create_version_guard
|
||||
|
||||
@@ -26,11 +26,11 @@ def create_api() -> Starlette:
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets', list_all_assets, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets/{asset_id}', get_asset_by_id, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets/{asset_id}', delete_asset_by_id, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/choices', get_choices, methods=['GET'], middleware=[ version_guard, session_guard ]),
|
||||
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/choices', get_choices, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/remote', remote, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/timeline/{count:int}', get_timeline, methods = [ 'GET' ], middleware = [ version_guard, session_guard ]),
|
||||
Route('/webrtc/offer', webrtc_offer, methods = [ 'POST' ], middleware = [ version_guard, session_guard ]),
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from starlette.datastructures import UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse, JSONResponse, Response
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import session_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.asset_helper import detect_media_type
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
from facefusion.filesystem import get_file_extension, remove_file
|
||||
|
||||
|
||||
async def upload_asset(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
asset_type = request.query_params.get('type')
|
||||
|
||||
if session_id and asset_type in [ 'source', 'target' ]:
|
||||
form = await request.form()
|
||||
upload_files = form.getlist('file')
|
||||
asset_paths = await save_asset_files(upload_files) # type: ignore[arg-type]
|
||||
|
||||
if asset_paths:
|
||||
asset_ids : List[str] = []
|
||||
|
||||
for asset_path in asset_paths:
|
||||
asset = asset_store.create_asset(session_id, asset_type, asset_path) # type: ignore[arg-type]
|
||||
asset_id = asset.get('id')
|
||||
|
||||
if asset_id:
|
||||
asset_ids.append(asset_id)
|
||||
|
||||
if asset_ids:
|
||||
return JSONResponse(
|
||||
{
|
||||
'asset_ids': asset_ids
|
||||
}, status_code = HTTP_201_CREATED)
|
||||
|
||||
return Response(status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
async def save_asset_files(upload_files : List[UploadFile]) -> List[str]:
|
||||
asset_paths : List[str] = []
|
||||
|
||||
for upload_file in upload_files:
|
||||
upload_file_extension = get_file_extension(upload_file.filename)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix = upload_file_extension, delete = False) as temp_file:
|
||||
|
||||
while upload_chunk := await upload_file.read(1024):
|
||||
temp_file.write(upload_chunk)
|
||||
|
||||
temp_file.flush()
|
||||
|
||||
media_type = detect_media_type(temp_file.name)
|
||||
|
||||
if media_type:
|
||||
asset_paths.append(temp_file.name)
|
||||
else:
|
||||
remove_file(temp_file.name)
|
||||
|
||||
return asset_paths
|
||||
|
||||
|
||||
async def get_assets(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if session_id:
|
||||
asset_set = asset_store.get_assets(session_id)
|
||||
assets = []
|
||||
|
||||
if asset_set:
|
||||
for asset in asset_set.values():
|
||||
assets.append(
|
||||
{
|
||||
'id': asset.get('id'),
|
||||
'created_at': asset.get('created_at').isoformat(),
|
||||
'expires_at': asset.get('expires_at').isoformat(),
|
||||
'type': asset.get('type'),
|
||||
'media': asset.get('media'),
|
||||
'name': asset.get('name'),
|
||||
'format': asset.get('format'),
|
||||
'size': asset.get('size'),
|
||||
'metadata': asset.get('metadata')
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'assets': assets
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
async def get_asset(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
asset_id = request.path_params.get('asset_id')
|
||||
action = request.query_params.get('action')
|
||||
|
||||
if session_id and asset_id:
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
if action == 'download':
|
||||
asset_path = asset.get('path')
|
||||
asset_name = asset.get('name')
|
||||
|
||||
return FileResponse(asset_path, filename = asset_name)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'id': asset.get('id'),
|
||||
'created_at': asset.get('created_at').isoformat(),
|
||||
'expires_at': asset.get('expires_at').isoformat(),
|
||||
'type': asset.get('type'),
|
||||
'media': asset.get('media'),
|
||||
'name': asset.get('name'),
|
||||
'format': asset.get('format'),
|
||||
'size': asset.get('size'),
|
||||
'metadata': asset.get('metadata')
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
async def delete_assets(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
body = await request.json()
|
||||
asset_ids = body.get('asset_ids')
|
||||
|
||||
if session_id and asset_ids:
|
||||
asset_set = asset_store.get_assets(session_id)
|
||||
|
||||
if asset_set:
|
||||
for asset_id in asset_ids:
|
||||
if asset_id in asset_set:
|
||||
remove_file(asset_set.get(asset_id).get('path'))
|
||||
asset_store.delete_assets(session_id, asset_ids)
|
||||
return Response(status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
@@ -30,7 +30,7 @@ async def create_session(request : Request) -> JSONResponse:
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
'message': translator.get('something_went_wrong', 'facefusion.apis')
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ async def get_session(request : Request) -> JSONResponse:
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
'message': translator.get('something_went_wrong', 'facefusion.apis')
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ async def refresh_session(request : Request) -> JSONResponse:
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
'message': translator.get('something_went_wrong', 'facefusion.apis')
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@@ -88,12 +88,12 @@ async def destroy_session(request : Request) -> JSONResponse:
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('ok', __package__)
|
||||
'message': translator.get('ok', 'facefusion.apis')
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('something_went_wrong', __package__)
|
||||
'message': translator.get('something_went_wrong', 'facefusion.apis')
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@@ -106,20 +106,19 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
|
||||
|
||||
if session_id:
|
||||
if session_manager.validate_session(session_id):
|
||||
from facefusion.session_context import set_session_id
|
||||
set_session_id(session_id)
|
||||
session_context.set_session_id(session_id)
|
||||
return await app(scope, receive, send)
|
||||
|
||||
response = JSONResponse(
|
||||
{
|
||||
'message': translator.get('invalid_access_token', __package__)
|
||||
'message': translator.get('invalid_access_token', 'facefusion.apis')
|
||||
}, status_code = HTTP_426_UPGRADE_REQUIRED)
|
||||
|
||||
return await response(scope, receive, send)
|
||||
|
||||
response = JSONResponse(
|
||||
{
|
||||
'message': translator.get('invalid_access_token', __package__)
|
||||
'message': translator.get('invalid_access_token', 'facefusion.apis')
|
||||
}, status_code = HTTP_401_UNAUTHORIZED)
|
||||
|
||||
return await response(scope, receive, send)
|
||||
@@ -0,0 +1,80 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import args_store, session_manager, state_manager, translator
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.endpoints.session import extract_access_token
|
||||
|
||||
|
||||
async def get_state(request : Request) -> JSONResponse:
|
||||
api_args = args_store.filter_api_args(state_manager.get_state())
|
||||
return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def set_state(request : Request) -> JSONResponse:
|
||||
action = request.query_params.get('action')
|
||||
asset_type = request.query_params.get('asset_type')
|
||||
|
||||
if action == 'select' and asset_type == 'source':
|
||||
return await select_source(request)
|
||||
|
||||
if action == 'select' and asset_type == 'target':
|
||||
return await select_target(request)
|
||||
|
||||
body = await request.json()
|
||||
api_args = args_store.get_api_args()
|
||||
|
||||
for key, value in body.items():
|
||||
if key in api_args:
|
||||
state_manager.set_item(key, value)
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state())
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def select_source(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
asset_ids = body.get('asset_ids')
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if isinstance(asset_ids, list) and session_id:
|
||||
source_paths = []
|
||||
|
||||
for asset_id in asset_ids:
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
source_paths.append(asset.get('path'))
|
||||
|
||||
state_manager.set_item('source_paths', source_paths)
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state())
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('source_asset_not_found', 'facefusion.apis')
|
||||
}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
async def select_target(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
asset_id = body.get('asset_id')
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if isinstance(asset_id, str) and session_id:
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
state_manager.set_item('target_path', asset.get('path'))
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state())
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('target_asset_not_found', 'facefusion.apis')
|
||||
}, status_code = HTTP_404_NOT_FOUND)
|
||||
@@ -0,0 +1,14 @@
|
||||
from facefusion.types import Locales
|
||||
|
||||
LOCALES : Locales =\
|
||||
{
|
||||
'en':
|
||||
{
|
||||
'ok': 'ok',
|
||||
'something_went_wrong': 'something went wrong',
|
||||
'invalid_access_token': 'invalid access token',
|
||||
'invalid_refresh_token': 'invalid refresh token',
|
||||
'source_asset_not_found': 'source asset not found',
|
||||
'target_asset_not_found': 'target asset not found'
|
||||
}
|
||||
}
|
||||
+21
-20
@@ -9,8 +9,10 @@ from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
from facefusion import asset_store, logger
|
||||
from facefusion import logger
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.choices import audio_formats
|
||||
from facefusion.session_context import get_session_id
|
||||
|
||||
|
||||
def resolve_image_urls(url : str) -> List[str]:
|
||||
@@ -66,12 +68,14 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]:
|
||||
gdl_job = gallery_job.DownloadJob(url)
|
||||
gdl_job.run()
|
||||
|
||||
session_id = get_session_id()
|
||||
for root, dirs, files in os.walk(output_dir):
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
asset_id = asset_store.register(asset_type, file_path, filename)
|
||||
asset_ids.append(asset_id)
|
||||
logger.info(f'Registered image as asset {asset_id}', __name__)
|
||||
asset = asset_store.create_asset(session_id, asset_type, file_path)
|
||||
if asset:
|
||||
asset_ids.append(asset.get('id'))
|
||||
logger.info(f'Registered image as asset {asset.get("id")}', __name__)
|
||||
|
||||
break
|
||||
|
||||
@@ -108,9 +112,11 @@ def download_images_from_url(url : str, asset_type : str) -> List[str]:
|
||||
for chunk in response.iter_bytes(chunk_size = 8192):
|
||||
f.write(chunk)
|
||||
|
||||
asset_id = asset_store.register(asset_type, file_path, filename)
|
||||
asset_ids.append(asset_id)
|
||||
logger.info(f'Downloaded and registered image as asset {asset_id}', __name__)
|
||||
session_id = get_session_id()
|
||||
asset = asset_store.create_asset(session_id, asset_type, file_path)
|
||||
if asset:
|
||||
asset_ids.append(asset.get('id'))
|
||||
logger.info(f'Downloaded and registered image as asset {asset.get("id")}', __name__)
|
||||
|
||||
return asset_ids
|
||||
|
||||
@@ -138,9 +144,11 @@ def download_audio_from_url(url : str, asset_type : str) -> List[str]:
|
||||
for chunk in response.iter_bytes(chunk_size = 8192):
|
||||
f.write(chunk)
|
||||
|
||||
asset_id = asset_store.register(asset_type, file_path, filename)
|
||||
asset_ids.append(asset_id)
|
||||
logger.info(f'Downloaded and registered audio as asset {asset_id}', __name__)
|
||||
session_id = get_session_id()
|
||||
asset = asset_store.create_asset(session_id, asset_type, file_path)
|
||||
if asset:
|
||||
asset_ids.append(asset.get('id'))
|
||||
logger.info(f'Downloaded and registered audio as asset {asset.get("id")}', __name__)
|
||||
|
||||
return asset_ids
|
||||
|
||||
@@ -352,16 +360,9 @@ async def remote(request : Request) -> JSONResponse:
|
||||
total_frames = int(duration * fps)
|
||||
logger.info(f'Calculated total frames: {total_frames} ({duration}s * {fps} fps)', __name__)
|
||||
|
||||
filename = os.path.basename(downloaded_file)
|
||||
metadata =\
|
||||
{
|
||||
'frame_total': total_frames,
|
||||
'fps': fps,
|
||||
'resolution': (width, height) if width and height else None,
|
||||
'duration': duration
|
||||
}
|
||||
|
||||
asset_id = asset_store.register(asset_type, downloaded_file, filename, metadata)
|
||||
session_id = get_session_id()
|
||||
asset = asset_store.create_asset(session_id, asset_type, downloaded_file)
|
||||
asset_id = asset.get('id') if asset else None
|
||||
logger.info(f'Video downloaded and registered as asset {asset_id}', __name__)
|
||||
|
||||
response_data =\
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import args_store, asset_store, logger, state_manager
|
||||
|
||||
|
||||
async def get_state(request : Request) -> JSONResponse:
|
||||
api_args = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(api_args), status_code = HTTP_200_OK)
|
||||
|
||||
|
||||
async def set_state(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
action = request.query_params.get('action')
|
||||
|
||||
if action == 'select':
|
||||
asset_type = request.query_params.get('asset_type')
|
||||
|
||||
if not asset_type:
|
||||
return JSONResponse({'message': 'Missing required query parameter: asset_type'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if asset_type not in ['source', 'target']:
|
||||
return JSONResponse({'message': 'Invalid asset_type. Must be "source" or "target"'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if asset_type == 'source':
|
||||
asset_ids = body.get('asset_ids', [])
|
||||
|
||||
if not isinstance(asset_ids, list):
|
||||
return JSONResponse({'message': 'asset_ids must be an array'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not asset_ids:
|
||||
return JSONResponse({'message': 'asset_ids cannot be empty'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
source_paths = []
|
||||
for asset_id in asset_ids:
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
if not asset:
|
||||
return JSONResponse({'message': f'Source asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND)
|
||||
source_paths.append(asset['path'])
|
||||
|
||||
state_manager.set_item('source_paths', source_paths)
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
if asset_type == 'target':
|
||||
asset_id = body.get('asset_id')
|
||||
|
||||
if not asset_id:
|
||||
return JSONResponse({'message': 'Missing required field: asset_id'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not isinstance(asset_id, str):
|
||||
return JSONResponse({'message': 'asset_id must be a string'}, status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
if not asset:
|
||||
return JSONResponse({'message': f'Target asset not found: {asset_id}'}, status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
state_manager.set_item('target_path', asset['path'])
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
|
||||
api_args = args_store.get_api_args()
|
||||
logger.info(f'[State] Normal update - body keys: {list(body.keys())}', __name__)
|
||||
|
||||
for key, value in body.items():
|
||||
if key in api_args:
|
||||
state_manager.set_item(key, value)
|
||||
logger.debug(f'[State] Set {key} = {value}', __name__)
|
||||
else:
|
||||
logger.warn(f'[State] Skipped {key} (not in api_args)', __name__)
|
||||
|
||||
__api_args__ = args_store.filter_api_args(state_manager.get_state()) #type:ignore[arg-type]
|
||||
return JSONResponse(state_manager.collect_state(__api_args__), status_code = HTTP_200_OK)
|
||||
@@ -9,7 +9,7 @@ from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
|
||||
|
||||
from facefusion import logger
|
||||
from facefusion.asset_store import get_asset
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.filesystem import is_video
|
||||
from facefusion.video_manager import get_video_capture
|
||||
from facefusion.vision import fit_contain_frame
|
||||
@@ -81,14 +81,11 @@ async def get_timeline(request: Request) -> JSONResponse:
|
||||
if asset_id and not target_path:
|
||||
from facefusion.session_context import get_session_id
|
||||
|
||||
asset = get_asset(asset_id)
|
||||
session_id = get_session_id()
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
if not asset:
|
||||
return JSONResponse({'message': f'Asset not found: {asset_id}'}, status_code=HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Verify asset belongs to current session (security)
|
||||
if asset.get('session_id') != get_session_id():
|
||||
return JSONResponse({'message': 'Asset not found'}, status_code=HTTP_400_BAD_REQUEST)
|
||||
|
||||
target_path = asset.get('path')
|
||||
if not target_path:
|
||||
return JSONResponse({'message': 'Asset has no path'}, status_code=HTTP_400_BAD_REQUEST)
|
||||
|
||||
@@ -162,6 +162,70 @@ AudioTypeSet : TypeAlias = Dict[AudioFormat, str]
|
||||
ImageTypeSet : TypeAlias = Dict[ImageFormat, str]
|
||||
VideoTypeSet : TypeAlias = Dict[VideoFormat, str]
|
||||
|
||||
AssetId : TypeAlias = str
|
||||
AssetType = Literal['source', 'target']
|
||||
AudioMetadata = TypedDict('AudioMetadata',
|
||||
{
|
||||
'duration' : Duration,
|
||||
'sample_rate': int,
|
||||
'frame_total': int,
|
||||
'channels': int,
|
||||
'format': str
|
||||
})
|
||||
ImageMetadata = TypedDict('ImageMetadata',
|
||||
{
|
||||
'resolution' : Resolution
|
||||
})
|
||||
VideoMetadata = TypedDict('VideoMetadata',
|
||||
{
|
||||
'duration' : Duration,
|
||||
'frame_total' : int,
|
||||
'fps' : Fps,
|
||||
'resolution' : Resolution
|
||||
})
|
||||
AudioAsset = TypedDict('AudioAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['audio'],
|
||||
'name' : str,
|
||||
'format' : AudioFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : AudioMetadata
|
||||
})
|
||||
ImageAsset = TypedDict('ImageAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['image'],
|
||||
'name' : str,
|
||||
'format' : ImageFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : ImageMetadata
|
||||
})
|
||||
VideoAsset = TypedDict('VideoAsset',
|
||||
{
|
||||
'id' : AssetId,
|
||||
'created_at' : datetime,
|
||||
'expires_at' : datetime,
|
||||
'type' : AssetType,
|
||||
'media' : Literal['video'],
|
||||
'name' : str,
|
||||
'format' : VideoFormat,
|
||||
'size' : int,
|
||||
'path' : str,
|
||||
'metadata' : VideoMetadata
|
||||
})
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
AssetSet : TypeAlias = Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]
|
||||
AssetStore : TypeAlias = Dict[SessionId, AssetSet]
|
||||
|
||||
AudioEncoder = Literal['flac', 'aac', 'libmp3lame', 'libopus', 'libvorbis', 'pcm_s16le', 'pcm_s32le']
|
||||
VideoEncoder = Literal['libx264', 'libx264rgb', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf', 'h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox', 'rawvideo']
|
||||
EncoderSet = TypedDict('EncoderSet',
|
||||
|
||||
+109
-159
@@ -2,24 +2,24 @@ import os
|
||||
import tempfile
|
||||
from typing import Iterator
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from facefusion import asset_store, session_manager, state_manager
|
||||
from facefusion.session_context import clear_session_id, set_session_id
|
||||
from facefusion.apis import asset_store
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function', autouse = True)
|
||||
def before_each() -> None:
|
||||
session_manager.SESSIONS.clear()
|
||||
state_manager.clear_item('asset_registry')
|
||||
clear_session_id()
|
||||
asset_store.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'function')
|
||||
def temp_file() -> Iterator[str]:
|
||||
fd, path = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd, b'test file content')
|
||||
os.close(fd)
|
||||
image = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path, image)
|
||||
yield path
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
@@ -27,221 +27,171 @@ def temp_file() -> Iterator[str]:
|
||||
|
||||
@pytest.fixture(scope = 'function')
|
||||
def session_id() -> str:
|
||||
test_session_id = 'test-session-123'
|
||||
set_session_id(test_session_id)
|
||||
return test_session_id
|
||||
return 'test-session-123'
|
||||
|
||||
|
||||
def test_register_source_asset(temp_file : str, session_id : str) -> None:
|
||||
asset_id = asset_store.register('source', temp_file, 'test.jpg')
|
||||
def test_create_source_asset(temp_file : str, session_id : str) -> None:
|
||||
asset = asset_store.create_asset(session_id, 'source', temp_file)
|
||||
|
||||
assert isinstance(asset_id, str)
|
||||
assert len(asset_id) == 36
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
assert asset is not None
|
||||
assert asset.get('id') == asset_id
|
||||
assert asset.get('session_id') == session_id
|
||||
assert isinstance(asset.get('id'), str)
|
||||
assert len(asset.get('id')) == 36
|
||||
assert asset.get('type') == 'source'
|
||||
assert asset.get('path') == temp_file
|
||||
assert asset.get('filename') == 'test.jpg'
|
||||
assert asset.get('size') > 0
|
||||
assert asset.get('created_at')
|
||||
|
||||
|
||||
def test_register_target_asset(temp_file : str, session_id : str) -> None:
|
||||
asset_id = asset_store.register('target', temp_file, 'video.mp4')
|
||||
def test_create_target_asset(temp_file : str, session_id : str) -> None:
|
||||
asset = asset_store.create_asset(session_id, 'target', temp_file)
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
assert asset is not None
|
||||
assert asset.get('type') == 'target'
|
||||
assert asset.get('filename') == 'video.mp4'
|
||||
|
||||
|
||||
def test_register_output_asset(temp_file : str, session_id : str) -> None:
|
||||
metadata = {'fps': 30, 'resolution': [1920, 1080]}
|
||||
asset_id = asset_store.register('output', temp_file, 'output.mp4', metadata)
|
||||
def test_get_asset(temp_file : str, session_id : str) -> None:
|
||||
created_asset = asset_store.create_asset(session_id, 'source', temp_file)
|
||||
asset_id = created_asset.get('id')
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
assert asset.get('type') == 'output'
|
||||
assert asset.get('metadata') == metadata
|
||||
|
||||
|
||||
def test_register_invalid_type(temp_file : str, session_id : str) -> None:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
asset_store.register('invalid_type', temp_file, 'test.jpg')
|
||||
assert "Invalid asset_type" in str(exc.value)
|
||||
|
||||
|
||||
def test_register_without_session() -> None:
|
||||
fd, path = tempfile.mkstemp()
|
||||
os.close(fd)
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
asset_store.register('source', path, 'test.jpg')
|
||||
assert "No active session" in str(exc.value)
|
||||
finally:
|
||||
os.remove(path)
|
||||
|
||||
|
||||
def test_register_without_filename(temp_file : str, session_id : str) -> None:
|
||||
asset_id = asset_store.register('source', temp_file)
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
assert asset.get('filename') == os.path.basename(temp_file)
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
assert asset is not None
|
||||
assert asset.get('id') == asset_id
|
||||
assert asset.get('type') == 'source'
|
||||
|
||||
|
||||
def test_get_asset_not_found(session_id : str) -> None:
|
||||
asset = asset_store.get_asset('non-existent-id')
|
||||
asset = asset_store.get_asset(session_id, 'non-existent-id')
|
||||
assert asset is None
|
||||
|
||||
|
||||
def test_list_assets_empty(session_id : str) -> None:
|
||||
assets = asset_store.list_assets()
|
||||
assert assets == []
|
||||
def test_get_asset_wrong_session(temp_file : str, session_id : str) -> None:
|
||||
created_asset = asset_store.create_asset(session_id, 'source', temp_file)
|
||||
asset_id = created_asset.get('id')
|
||||
|
||||
asset = asset_store.get_asset('different-session', asset_id)
|
||||
assert asset is None
|
||||
|
||||
|
||||
def test_list_assets_with_multiple(temp_file : str, session_id : str) -> None:
|
||||
def test_get_assets_empty(session_id : str) -> None:
|
||||
assets = asset_store.get_assets(session_id)
|
||||
assert assets is None
|
||||
|
||||
|
||||
def test_get_assets_with_multiple(temp_file : str, session_id : str) -> None:
|
||||
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd1, b'content 1')
|
||||
os.close(fd1)
|
||||
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path1, image1)
|
||||
|
||||
fd2, path2 = tempfile.mkstemp(suffix = '.mp4')
|
||||
os.write(fd2, b'content 2')
|
||||
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.close(fd2)
|
||||
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path2, image2)
|
||||
|
||||
try:
|
||||
asset_store.register('source', path1, 'source1.jpg')
|
||||
asset_store.register('source', path2, 'source2.jpg')
|
||||
asset_store.register('target', temp_file, 'target.mp4')
|
||||
asset_store.create_asset(session_id, 'source', path1)
|
||||
asset_store.create_asset(session_id, 'source', path2)
|
||||
asset_store.create_asset(session_id, 'target', temp_file)
|
||||
|
||||
assets = asset_store.list_assets()
|
||||
assets = asset_store.get_assets(session_id)
|
||||
assert assets is not None
|
||||
assert len(assets) == 3
|
||||
finally:
|
||||
os.remove(path1)
|
||||
os.remove(path2)
|
||||
if os.path.exists(path1):
|
||||
os.remove(path1)
|
||||
if os.path.exists(path2):
|
||||
os.remove(path2)
|
||||
|
||||
|
||||
def test_list_assets_filter_by_type(temp_file : str, session_id : str) -> None:
|
||||
fd, path = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd, b'content')
|
||||
os.close(fd)
|
||||
|
||||
try:
|
||||
asset_store.register('source', path, 'source.jpg')
|
||||
asset_store.register('target', temp_file, 'target.mp4')
|
||||
|
||||
source_assets = asset_store.list_assets('source')
|
||||
assert len(source_assets) == 1
|
||||
assert source_assets[0].get('type') == 'source'
|
||||
|
||||
target_assets = asset_store.list_assets('target')
|
||||
assert len(target_assets) == 1
|
||||
assert target_assets[0].get('type') == 'target'
|
||||
|
||||
output_assets = asset_store.list_assets('output')
|
||||
assert len(output_assets) == 0
|
||||
finally:
|
||||
os.remove(path)
|
||||
|
||||
|
||||
def test_list_assets_invalid_type(session_id : str) -> None:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
asset_store.list_assets('invalid_type')
|
||||
assert "Invalid asset_type" in str(exc.value)
|
||||
|
||||
|
||||
def test_list_assets_session_scoped(temp_file : str) -> None:
|
||||
def test_get_assets_session_scoped(temp_file : str) -> None:
|
||||
session1_id = 'session-1'
|
||||
set_session_id(session1_id)
|
||||
asset1_id = asset_store.register('source', temp_file, 'file1.jpg')
|
||||
asset1 = asset_store.create_asset(session1_id, 'source', temp_file)
|
||||
asset1_id = asset1.get('id')
|
||||
|
||||
session2_id = 'session-2'
|
||||
set_session_id(session2_id)
|
||||
|
||||
fd, path2 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd, b'content 2')
|
||||
os.close(fd)
|
||||
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path2, image2)
|
||||
|
||||
try:
|
||||
asset2_id = asset_store.register('source', path2, 'file2.jpg')
|
||||
asset2 = asset_store.create_asset(session2_id, 'source', path2)
|
||||
asset2_id = asset2.get('id')
|
||||
|
||||
assets_session2 = asset_store.list_assets()
|
||||
assets_session2 = asset_store.get_assets(session2_id)
|
||||
assert assets_session2 is not None
|
||||
assert len(assets_session2) == 1
|
||||
assert assets_session2[0].get('id') == asset2_id
|
||||
assert asset2_id in assets_session2
|
||||
|
||||
set_session_id(session1_id)
|
||||
assets_session1 = asset_store.list_assets()
|
||||
assets_session1 = asset_store.get_assets(session1_id)
|
||||
assert assets_session1 is not None
|
||||
assert len(assets_session1) == 1
|
||||
assert assets_session1[0].get('id') == asset1_id
|
||||
assert asset1_id in assets_session1
|
||||
finally:
|
||||
os.remove(path2)
|
||||
if os.path.exists(path2):
|
||||
os.remove(path2)
|
||||
|
||||
|
||||
def test_delete_asset(temp_file : str, session_id : str) -> None:
|
||||
asset_id = asset_store.register('source', temp_file, 'test.jpg')
|
||||
|
||||
assert os.path.exists(temp_file)
|
||||
|
||||
success = asset_store.delete_asset(asset_id)
|
||||
assert success is True
|
||||
|
||||
assert not os.path.exists(temp_file)
|
||||
|
||||
asset = asset_store.get_asset(asset_id)
|
||||
assert asset is None
|
||||
|
||||
|
||||
def test_delete_asset_not_found(session_id : str) -> None:
|
||||
success = asset_store.delete_asset('non-existent-id')
|
||||
assert success is False
|
||||
|
||||
|
||||
def test_cleanup_session_assets(session_id : str) -> None:
|
||||
def test_delete_assets(session_id : str) -> None:
|
||||
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd1, b'content 1')
|
||||
os.close(fd1)
|
||||
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path1, image1)
|
||||
|
||||
fd2, path2 = tempfile.mkstemp(suffix = '.mp4')
|
||||
os.write(fd2, b'content 2')
|
||||
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.close(fd2)
|
||||
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path2, image2)
|
||||
|
||||
asset_id1 = asset_store.register('source', path1, 'source.jpg')
|
||||
asset_id2 = asset_store.register('target', path2, 'target.mp4')
|
||||
try:
|
||||
asset1 = asset_store.create_asset(session_id, 'source', path1)
|
||||
asset2 = asset_store.create_asset(session_id, 'source', path2)
|
||||
asset1_id = asset1.get('id')
|
||||
asset2_id = asset2.get('id')
|
||||
|
||||
assert os.path.exists(path1)
|
||||
assert os.path.exists(path2)
|
||||
asset_store.delete_assets(session_id, [asset1_id])
|
||||
|
||||
asset_store.cleanup_session_assets(session_id)
|
||||
|
||||
assert not os.path.exists(path1)
|
||||
assert not os.path.exists(path2)
|
||||
|
||||
assert asset_store.get_asset(asset_id1) is None
|
||||
assert asset_store.get_asset(asset_id2) is None
|
||||
assets = asset_store.get_assets(session_id)
|
||||
assert assets is not None
|
||||
assert len(assets) == 1
|
||||
assert asset2_id in assets
|
||||
assert asset1_id not in assets
|
||||
finally:
|
||||
if os.path.exists(path1):
|
||||
os.remove(path1)
|
||||
if os.path.exists(path2):
|
||||
os.remove(path2)
|
||||
|
||||
|
||||
def test_cleanup_session_assets_only_affects_target_session(temp_file : str) -> None:
|
||||
session1_id = 'session-1'
|
||||
set_session_id(session1_id)
|
||||
def test_delete_assets_not_found(session_id : str) -> None:
|
||||
asset_store.delete_assets(session_id, ['non-existent-id'])
|
||||
|
||||
fd, path1 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.write(fd, b'content 1')
|
||||
os.close(fd)
|
||||
|
||||
asset1_id = asset_store.register('source', path1, 'file1.jpg')
|
||||
def test_clear() -> None:
|
||||
fd1, path1 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.close(fd1)
|
||||
image1 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path1, image1)
|
||||
|
||||
session2_id = 'session-2'
|
||||
set_session_id(session2_id)
|
||||
asset2_id = asset_store.register('source', temp_file, 'file2.jpg')
|
||||
fd2, path2 = tempfile.mkstemp(suffix = '.jpg')
|
||||
os.close(fd2)
|
||||
image2 = numpy.zeros((100, 100, 3), dtype = numpy.uint8)
|
||||
cv2.imwrite(path2, image2)
|
||||
|
||||
asset_store.cleanup_session_assets(session1_id)
|
||||
try:
|
||||
session1_id = 'session-1'
|
||||
session2_id = 'session-2'
|
||||
|
||||
assert not os.path.exists(path1)
|
||||
assert os.path.exists(temp_file)
|
||||
asset_store.create_asset(session1_id, 'source', path1)
|
||||
asset_store.create_asset(session2_id, 'source', path2)
|
||||
|
||||
set_session_id(session1_id)
|
||||
assert asset_store.get_asset(asset1_id) is None
|
||||
asset_store.clear()
|
||||
|
||||
set_session_id(session2_id)
|
||||
assert asset_store.get_asset(asset2_id) is not None
|
||||
assert asset_store.get_assets(session1_id) is None
|
||||
assert asset_store.get_assets(session2_id) is None
|
||||
finally:
|
||||
if os.path.exists(path1):
|
||||
os.remove(path1)
|
||||
if os.path.exists(path2):
|
||||
os.remove(path2)
|
||||
|
||||
Reference in New Issue
Block a user