mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 17:36:16 +02:00
Feat/session context (#993)
* Add simple session context * Add simple session context
This commit is contained in:
@@ -8,7 +8,7 @@ from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, HTTP_426_UPGRADE_REQUIRED
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from facefusion import session_manager, translator
|
||||
from facefusion import session_context, session_manager, translator
|
||||
from facefusion.types import Token
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ async def create_session(request : Request) -> JSONResponse:
|
||||
if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'):
|
||||
session_id = secrets.token_urlsafe(16)
|
||||
session = session_manager.create_session()
|
||||
session_context.set_session_id(session_id)
|
||||
session_manager.set_session(session_id, session)
|
||||
|
||||
return JSONResponse(
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.types import SessionId
|
||||
|
||||
SESSION_ID : ContextVar[Optional[SessionId]] = ContextVar('SESSION_ID', default = None)
|
||||
|
||||
|
||||
def set_session_id(session_id : SessionId) -> None:
|
||||
SESSION_ID.set(session_id)
|
||||
|
||||
|
||||
def get_session_id() -> Optional[SessionId]:
|
||||
return SESSION_ID.get()
|
||||
|
||||
|
||||
def clear_session_id() -> None:
|
||||
SESSION_ID.set(None)
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Union
|
||||
|
||||
from facefusion.app_context import detect_app_context
|
||||
from facefusion.processors.types import ProcessorState, ProcessorStateKey, ProcessorStateSet
|
||||
from facefusion.session_context import get_session_id
|
||||
from facefusion.types import Args, State, StateKey, StateSet
|
||||
|
||||
STATE_SET : Union[StateSet, ProcessorStateSet] =\
|
||||
@@ -45,7 +46,7 @@ def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None:
|
||||
|
||||
def get_jobs_path() -> str:
|
||||
jobs_path = get_item('jobs_path')
|
||||
session_id = get_item('session_id')
|
||||
session_id = get_session_id()
|
||||
|
||||
if session_id:
|
||||
return os.path.join(jobs_path, session_id)
|
||||
@@ -54,7 +55,7 @@ def get_jobs_path() -> str:
|
||||
|
||||
def get_temp_path() -> str:
|
||||
temp_path = get_item('temp_path')
|
||||
session_id = get_item('session_id')
|
||||
session_id = get_session_id()
|
||||
|
||||
if session_id:
|
||||
return os.path.join(temp_path, session_id)
|
||||
|
||||
Reference in New Issue
Block a user