mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-29 04:55:57 +02:00
add missing endpoints
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, cast
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from facefusion.apis.asset_helper import detect_media_type, extract_audio_metadata, extract_image_metadata, extract_video_metadata
|
||||
from facefusion.filesystem import get_file_format, get_file_name
|
||||
from facefusion.types import AssetId, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
from facefusion.types import AssetId, AssetSet, AssetStore, AssetType, AudioAsset, AudioFormat, ImageAsset, ImageFormat, SessionId, VideoAsset, VideoFormat
|
||||
|
||||
ASSET_STORE : AssetStore = {}
|
||||
|
||||
@@ -70,9 +70,21 @@ def create_asset(session_id : SessionId, asset_type : AssetType, asset_path : st
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
|
||||
|
||||
def get_assets(session_id : SessionId) -> Optional[AssetSet]:
|
||||
return ASSET_STORE.get(session_id)
|
||||
|
||||
|
||||
def get_asset(session_id : SessionId, asset_id : AssetId) -> Optional[AudioAsset | ImageAsset | VideoAsset]:
|
||||
if session_id in ASSET_STORE:
|
||||
return ASSET_STORE[session_id].get(asset_id)
|
||||
return ASSET_STORE.get(session_id).get(asset_id)
|
||||
return None
|
||||
|
||||
|
||||
def delete_assets(session_id : SessionId, asset_ids : List[AssetId]) -> None:
|
||||
if session_id in ASSET_STORE:
|
||||
for asset_id in asset_ids:
|
||||
if asset_id in ASSET_STORE.get(session_id):
|
||||
del ASSET_STORE[session_id][asset_id]
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Route, WebSocketRoute
|
||||
|
||||
from facefusion.apis.endpoints.assets import upload_asset
|
||||
from facefusion.apis.endpoints.assets import delete_assets, get_asset, get_assets, upload_asset
|
||||
from facefusion.apis.endpoints.ping import websocket_ping
|
||||
from facefusion.apis.endpoints.session import create_session, create_session_guard, destroy_session, get_session, refresh_session
|
||||
from facefusion.apis.endpoints.state import get_state, set_state
|
||||
@@ -19,7 +19,10 @@ def create_api() -> Starlette:
|
||||
Route('/session', destroy_session, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
Route('/state', get_state, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/state', set_state, methods = [ 'PUT' ], middleware = [ session_guard ]),
|
||||
Route('/assets', get_assets, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/assets', upload_asset, methods = [ 'POST' ], middleware = [ session_guard ]),
|
||||
Route('/assets/{asset_id}', get_asset, methods = [ 'GET' ], middleware = [ session_guard ]),
|
||||
Route('/assets', delete_assets, methods = [ 'DELETE' ], middleware = [ session_guard ]),
|
||||
WebSocketRoute('/ping', websocket_ping, middleware = [ session_guard ])
|
||||
]
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List
|
||||
from starlette.datastructures import UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
|
||||
|
||||
from facefusion import ffmpeg, process_manager, session_manager, state_manager
|
||||
from facefusion.apis import asset_store
|
||||
@@ -22,7 +22,7 @@ async def upload_asset(request : Request) -> Response:
|
||||
if session_id and asset_type in [ 'source', 'target' ]:
|
||||
form = await request.form()
|
||||
upload_files = form.getlist('file')
|
||||
asset_paths = await save_asset_files(upload_files)
|
||||
asset_paths = await save_asset_files(upload_files) # type: ignore[arg-type]
|
||||
|
||||
if asset_paths:
|
||||
asset_ids : List[str] = []
|
||||
@@ -76,3 +76,78 @@ async def save_asset_files(upload_files : List[UploadFile]) -> List[str]:
|
||||
remove_file(temp_file.name)
|
||||
|
||||
return asset_paths
|
||||
|
||||
|
||||
async def get_assets(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if session_id:
|
||||
asset_set = asset_store.get_assets(session_id)
|
||||
assets = []
|
||||
|
||||
if asset_set:
|
||||
for asset in asset_set.values():
|
||||
assets.append(
|
||||
{
|
||||
'id': asset.get('id'),
|
||||
'created_at': asset.get('created_at').isoformat(),
|
||||
'expires_at': asset.get('expires_at').isoformat(),
|
||||
'type': asset.get('type'),
|
||||
'media': asset.get('media'),
|
||||
'name': asset.get('name'),
|
||||
'format': asset.get('format'),
|
||||
'size': asset.get('size'),
|
||||
'metadata': asset.get('metadata')
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'assets': assets
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
||||
async def get_asset(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
asset_id = request.path_params.get('asset_id')
|
||||
|
||||
if session_id and asset_id:
|
||||
asset = asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
if asset:
|
||||
return JSONResponse(
|
||||
{
|
||||
'id': asset.get('id'),
|
||||
'created_at': asset.get('created_at').isoformat(),
|
||||
'expires_at': asset.get('expires_at').isoformat(),
|
||||
'type': asset.get('type'),
|
||||
'media': asset.get('media'),
|
||||
'name': asset.get('name'),
|
||||
'format': asset.get('format'),
|
||||
'size': asset.get('size'),
|
||||
'metadata': asset.get('metadata')
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
async def delete_assets(request : Request) -> Response:
|
||||
access_token = extract_access_token(request.scope)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
body = await request.json()
|
||||
asset_ids = body.get('asset_ids')
|
||||
|
||||
if session_id and asset_ids:
|
||||
asset_set = asset_store.get_assets(session_id)
|
||||
|
||||
if asset_set:
|
||||
for asset_id in asset_ids:
|
||||
if asset_id in asset_set:
|
||||
remove_file(asset_set.get(asset_id).get('path'))
|
||||
asset_store.delete_assets(session_id, asset_ids)
|
||||
return Response(status_code = HTTP_200_OK)
|
||||
|
||||
return Response(status_code = HTTP_404_NOT_FOUND)
|
||||
|
||||
+2
-1
@@ -232,7 +232,8 @@ VideoAsset = TypedDict('VideoAsset',
|
||||
})
|
||||
|
||||
AssetMetadata : TypeAlias = AudioMetadata | ImageMetadata | VideoMetadata
|
||||
AssetStore : TypeAlias = Dict[SessionId, Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]]
|
||||
AssetSet : TypeAlias = Dict[AssetId, AudioAsset | ImageAsset | VideoAsset]
|
||||
AssetStore : TypeAlias = Dict[SessionId, AssetSet]
|
||||
|
||||
BenchmarkMode = Literal['warm', 'cold']
|
||||
BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p']
|
||||
|
||||
+170
-1
@@ -4,7 +4,7 @@ from typing import Iterator
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from facefusion import metadata, process_manager, session_manager, state_manager
|
||||
from facefusion import metadata, session_manager, state_manager
|
||||
from facefusion.apis import asset_store
|
||||
from facefusion.apis.core import create_api
|
||||
from facefusion.download import conditional_download
|
||||
@@ -101,3 +101,172 @@ def test_upload_asset(test_client : TestClient) -> None:
|
||||
})
|
||||
|
||||
assert upload_response.status_code == 400
|
||||
|
||||
|
||||
def test_get_assets(test_client : TestClient) -> None:
|
||||
get_response = test_client.get('/assets')
|
||||
|
||||
assert get_response.status_code == 401
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
access_token = create_session_body.get('access_token')
|
||||
|
||||
get_response = test_client.get('/assets', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
get_body = get_response.json()
|
||||
|
||||
assert get_body.get('assets') == []
|
||||
|
||||
assert get_response.status_code == 200
|
||||
|
||||
source_path = get_test_example_file('source.jpg')
|
||||
target_path = get_test_example_file('target-240p.mp4')
|
||||
|
||||
with open(source_path, 'rb') as source_file, open(target_path, 'rb') as target_file:
|
||||
source_content = source_file.read()
|
||||
target_content = target_file.read()
|
||||
test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
}, files =
|
||||
[
|
||||
('file', ('source.jpg', source_content, 'image/jpeg')),
|
||||
('file', ('target.mp4', target_content, 'video/mp4'))
|
||||
])
|
||||
|
||||
get_response = test_client.get('/assets', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
get_body = get_response.json()
|
||||
assets = get_body.get('assets')
|
||||
|
||||
assert len(assets) == 2
|
||||
assert assets[0].get('media') == 'image'
|
||||
assert assets[1].get('media') == 'video'
|
||||
|
||||
assert get_response.status_code == 200
|
||||
|
||||
|
||||
def test_get_asset(test_client : TestClient) -> None:
|
||||
get_response = test_client.get('invalid')
|
||||
|
||||
assert get_response.status_code == 404
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
access_token = create_session_body.get('access_token')
|
||||
|
||||
source_path = get_test_example_file('source.jpg')
|
||||
|
||||
with open(source_path, 'rb') as source_file:
|
||||
source_content = source_file.read()
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
}, files =
|
||||
[
|
||||
('file', ('source.jpg', source_content, 'image/jpeg'))
|
||||
])
|
||||
upload_body = upload_response.json()
|
||||
asset_id = upload_body.get('asset_ids')[0]
|
||||
|
||||
second_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
second_session_body = second_session_response.json()
|
||||
second_access_token = second_session_body.get('access_token')
|
||||
|
||||
get_response = test_client.get('/assets/' + asset_id, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + second_access_token
|
||||
})
|
||||
|
||||
assert get_response.status_code == 404
|
||||
|
||||
get_response = test_client.get('/assets/' + asset_id, headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
})
|
||||
get_body = get_response.json()
|
||||
|
||||
assert get_body.get('id') == asset_id
|
||||
assert get_body.get('type') == 'source'
|
||||
assert get_body.get('media') == 'image'
|
||||
assert get_body.get('format') == 'jpeg'
|
||||
assert get_body.get('metadata').get('resolution') == [ 1024, 1024 ]
|
||||
|
||||
assert get_response.status_code == 200
|
||||
|
||||
|
||||
def test_delete_assets(test_client : TestClient) -> None:
|
||||
create_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
access_token = create_session_body.get('access_token')
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
source_path = get_test_example_file('source.jpg')
|
||||
|
||||
with open(source_path, 'rb') as source_file:
|
||||
source_content = source_file.read()
|
||||
upload_response = test_client.post('/assets?type=source', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
}, files =
|
||||
[
|
||||
('file', ('source.jpg', source_content, 'image/jpeg'))
|
||||
])
|
||||
upload_body = upload_response.json()
|
||||
asset_id = upload_body.get('asset_ids')[0]
|
||||
|
||||
assert asset_store.get_asset(session_id, asset_id)
|
||||
|
||||
second_session_response = test_client.post('/session', json =
|
||||
{
|
||||
'client_version': metadata.get('version')
|
||||
})
|
||||
second_session_body = second_session_response.json()
|
||||
second_access_token = second_session_body.get('access_token')
|
||||
|
||||
delete_response = test_client.request('DELETE', '/assets', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + second_access_token
|
||||
}, json =
|
||||
{
|
||||
'asset_ids': [ asset_id ]
|
||||
})
|
||||
|
||||
assert delete_response.status_code == 404
|
||||
|
||||
delete_response = test_client.request('DELETE', '/assets', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
}, json =
|
||||
{
|
||||
'asset_ids': [ asset_id ]
|
||||
})
|
||||
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
delete_response = test_client.request('DELETE', '/assets', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
}, json =
|
||||
{
|
||||
'asset_ids': [ asset_id ]
|
||||
})
|
||||
|
||||
assert delete_response.status_code == 404
|
||||
|
||||
Reference in New Issue
Block a user