From dd772279edbabc55b063bb3bde663e50bdb3afd6 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Mon, 17 Nov 2025 11:09:47 +0100 Subject: [PATCH] Add session_id, make token size more reasonable (#983) * Add session_id, make token size more reasonable * Use more direct approach --- facefusion/apis/session.py | 38 +++++++++++++++++++---------------- facefusion/session_manager.py | 33 ++++++++++++++++++------------ facefusion/types.py | 1 + tests/test_api_session.py | 19 +++++++++--------- tests/test_session_manager.py | 26 ++++++++++++------------ 5 files changed, 65 insertions(+), 52 deletions(-) diff --git a/facefusion/apis/session.py b/facefusion/apis/session.py index d18946dc..30fbcb0b 100644 --- a/facefusion/apis/session.py +++ b/facefusion/apis/session.py @@ -1,4 +1,5 @@ import os +import secrets from typing import Optional from starlette.datastructures import Headers @@ -15,8 +16,9 @@ async def create_session(request : Request) -> JSONResponse: body = await request.json() if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'): + session_id = secrets.token_urlsafe(16) session = session_manager.create_session() - session_manager.set_session(session.get('access_token'), session) + session_manager.set_session(session_id, session) return JSONResponse( { @@ -34,9 +36,11 @@ async def get_session(request : Request) -> JSONResponse: access_token = extract_access_token(request.headers) if access_token: - session = session_manager.get_session(access_token) + session_id = session_manager.find_session_id(access_token) + + if session_id: + session = session_manager.get_session(session_id) - if session: return JSONResponse( { 'access_token': session.get('access_token'), @@ -54,16 +58,15 @@ async def get_session(request : Request) -> JSONResponse: async def refresh_session(request : Request) -> JSONResponse: body = await request.json() - for access_token, session in session_manager.SESSIONS.items(): + for session_id, session in session_manager.SESSIONS.items(): if session.get('refresh_token') == body.get('refresh_token'): - session_manager.clear_session(access_token) - session = session_manager.create_session() - session_manager.set_session(session.get('access_token'), session) + __session__ = session_manager.create_session() + session_manager.set_session(session_id, __session__) return JSONResponse( { - 'access_token': session.get('access_token'), - 'refresh_token': session.get('refresh_token') + 'access_token': __session__.get('access_token'), + 'refresh_token': __session__.get('refresh_token') }, status_code = HTTP_200_OK) return JSONResponse( @@ -76,10 +79,11 @@ async def destroy_session(request : Request) -> JSONResponse: access_token = extract_access_token(request.headers) if access_token: - session = session_manager.get_session(access_token) + session_id = session_manager.find_session_id(access_token) + + if session_id: + session_manager.clear_session(session_id) - if session: - session_manager.clear_session(access_token) return JSONResponse( { 'message': translator.get('ok', __package__) @@ -95,13 +99,13 @@ def create_session_guard(app : ASGIApp) -> ASGIApp: async def middleware(scope : Scope, receive : Receive, send : Send) -> None: access_token = extract_access_token(Headers(scope = scope)) - if access_token and session_manager.validate_session(access_token): - return await app(scope, receive, send) - if access_token: - session = session_manager.get_session(access_token) + session_id = session_manager.find_session_id(access_token) + + if session_id: + if session_manager.validate_session(session_id): + return await app(scope, receive, send) - if session: response = JSONResponse( { 'message': translator.get('invalid_access_token', __package__) diff --git a/facefusion/session_manager.py b/facefusion/session_manager.py index 8a27abca..06284147 100644 --- a/facefusion/session_manager.py +++ b/facefusion/session_manager.py @@ -3,16 +3,16 @@ from datetime import datetime, timedelta from typing import Dict from typing import Optional -from facefusion.types import Session, Token +from facefusion.types import Session, SessionId -SESSIONS : Dict[Token, Session] = {} +SESSIONS : Dict[SessionId, Session] = {} def create_session() -> Session: session : Session =\ { - 'access_token': secrets.token_urlsafe(128), - 'refresh_token': secrets.token_urlsafe(128), + 'access_token': secrets.token_urlsafe(64), + 'refresh_token': secrets.token_urlsafe(64), 'created_at': datetime.now(), 'expires_at': datetime.now() + timedelta(minutes = 10) } @@ -20,20 +20,27 @@ def create_session() -> Session: return session -def get_session(access_token : Token) -> Optional[Session]: - return SESSIONS.get(access_token) +def get_session(session_id : SessionId) -> Optional[Session]: + return SESSIONS.get(session_id) -def set_session(access_token : Token, session : Session) -> None: - SESSIONS[access_token] = session +def find_session_id(access_token : str) -> Optional[SessionId]: + for session_id, session in SESSIONS.items(): + if session.get('access_token') == access_token: + return session_id + return None -def validate_session(access_token : Token) -> bool: - session = get_session(access_token) +def set_session(session_id : SessionId, session : Session) -> None: + SESSIONS[session_id] = session + + +def validate_session(session_id : SessionId) -> bool: + session = get_session(session_id) return session and datetime.now() <= session.get('expires_at') -def clear_session(access_token : Token) -> None: - if access_token in SESSIONS: - del SESSIONS[access_token] +def clear_session(session_id : SessionId) -> None: + if session_id in SESSIONS: + del SESSIONS[session_id] diff --git a/facefusion/types.py b/facefusion/types.py index d55be4ff..61948df6 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -103,6 +103,7 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool] Content : TypeAlias = Dict[str, Any] Token : TypeAlias = str +SessionId : TypeAlias = str Session = TypedDict('Session', { 'access_token': Token, diff --git a/tests/test_api_session.py b/tests/test_api_session.py index 29f83eb9..e9526635 100644 --- a/tests/test_api_session.py +++ b/tests/test_api_session.py @@ -28,7 +28,8 @@ def test_create_session(test_client : TestClient) -> None: }) create_session_body = create_session_response.json() - assert session_manager.get_session(create_session_body.get('access_token')) + assert create_session_body.get('access_token') + assert create_session_body.get('refresh_token') assert create_session_response.status_code == 201 create_session_response = test_client.post('/session', json = @@ -78,9 +79,9 @@ def test_get_session(test_client : TestClient) -> None: assert get_session_response.status_code == 200 - access_token = create_session_body.get('access_token') - session : Session = session_manager.get_session(access_token) - session_manager.set_session(access_token, + session_id = session_manager.find_session_id(create_session_body.get('access_token')) + session : Session = session_manager.get_session(session_id) + session_manager.set_session(session_id, { 'access_token': session.get('access_token'), 'refresh_token': session.get('refresh_token'), @@ -90,7 +91,7 @@ def test_get_session(test_client : TestClient) -> None: get_session_response = test_client.get('/session', headers = { - 'Authorization': 'Bearer ' + access_token + 'Authorization': 'Bearer ' + create_session_body.get('access_token') }) assert get_session_response.status_code == 426 @@ -116,9 +117,9 @@ def test_refresh_session(test_client : TestClient) -> None: }) refresh_session_body = refresh_session_response.json() - assert session_manager.get_session(create_session_body.get('access_token')) is None - - assert session_manager.get_session(refresh_session_body.get('access_token')) + assert refresh_session_body.get('access_token') + assert refresh_session_body.get('refresh_token') + assert not refresh_session_body.get('access_token') == create_session_body.get('access_token') assert refresh_session_response.status_code == 200 @@ -149,6 +150,6 @@ def test_destroy_session(test_client : TestClient) -> None: 'Authorization': 'Bearer ' + create_session_body.get('access_token') }) - assert session_manager.get_session(create_session_body.get('access_token')) is None + assert session_manager.find_session_id(create_session_body.get('access_token')) is None assert delete_session_response.status_code == 200 diff --git a/tests/test_session_manager.py b/tests/test_session_manager.py index c886c541..59e94938 100644 --- a/tests/test_session_manager.py +++ b/tests/test_session_manager.py @@ -6,22 +6,22 @@ from facefusion.session_manager import clear_session, create_session, get_sessio def test_get_and_set_session() -> None: session = create_session() - access_token = secrets.token_urlsafe(128) + session_id = secrets.token_urlsafe(16) - set_session(access_token, session) + set_session(session_id, session) - assert get_session(access_token) == session + assert get_session(session_id) == session def test_validate_session() -> None: session = create_session() - access_token = secrets.token_urlsafe(128) + session_id = secrets.token_urlsafe(16) - set_session(access_token, session) + set_session(session_id, session) - assert validate_session(access_token) is True + assert validate_session(session_id) is True - set_session(access_token, + set_session(session_id, { 'access_token': session.get('access_token'), 'refresh_token': session.get('refresh_token'), @@ -29,17 +29,17 @@ def test_validate_session() -> None: 'expires_at': session.get('expires_at') - timedelta(hours = 1) }) - assert validate_session(access_token) is False + assert validate_session(session_id) is False def test_clear_session() -> None: session = create_session() - access_token = secrets.token_urlsafe(128) + session_id = secrets.token_urlsafe(16) - set_session(access_token, session) + set_session(session_id, session) - assert validate_session(access_token) is True + assert validate_session(session_id) is True - clear_session(access_token) + clear_session(session_id) - assert validate_session(access_token) is None + assert validate_session(session_id) is None