Add session_id, make token size more reasonable (#983)

* Add session_id, make token size more reasonable

* Use more direct approach
This commit is contained in:
Henry Ruhs
2025-11-17 11:09:47 +01:00
committed by henryruhs
parent 04f6a2c758
commit dd772279ed
5 changed files with 65 additions and 52 deletions
+21 -17
View File
@@ -1,4 +1,5 @@
import os
import secrets
from typing import Optional
from starlette.datastructures import Headers
@@ -15,8 +16,9 @@ async def create_session(request : Request) -> JSONResponse:
body = await request.json()
if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'):
session_id = secrets.token_urlsafe(16)
session = session_manager.create_session()
session_manager.set_session(session.get('access_token'), session)
session_manager.set_session(session_id, session)
return JSONResponse(
{
@@ -34,9 +36,11 @@ async def get_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
if access_token:
session = session_manager.get_session(access_token)
session_id = session_manager.find_session_id(access_token)
if session_id:
session = session_manager.get_session(session_id)
if session:
return JSONResponse(
{
'access_token': session.get('access_token'),
@@ -54,16 +58,15 @@ async def get_session(request : Request) -> JSONResponse:
async def refresh_session(request : Request) -> JSONResponse:
body = await request.json()
for access_token, session in session_manager.SESSIONS.items():
for session_id, session in session_manager.SESSIONS.items():
if session.get('refresh_token') == body.get('refresh_token'):
session_manager.clear_session(access_token)
session = session_manager.create_session()
session_manager.set_session(session.get('access_token'), session)
__session__ = session_manager.create_session()
session_manager.set_session(session_id, __session__)
return JSONResponse(
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token')
'access_token': __session__.get('access_token'),
'refresh_token': __session__.get('refresh_token')
}, status_code = HTTP_200_OK)
return JSONResponse(
@@ -76,10 +79,11 @@ async def destroy_session(request : Request) -> JSONResponse:
access_token = extract_access_token(request.headers)
if access_token:
session = session_manager.get_session(access_token)
session_id = session_manager.find_session_id(access_token)
if session_id:
session_manager.clear_session(session_id)
if session:
session_manager.clear_session(access_token)
return JSONResponse(
{
'message': translator.get('ok', __package__)
@@ -95,13 +99,13 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
access_token = extract_access_token(Headers(scope = scope))
if access_token and session_manager.validate_session(access_token):
return await app(scope, receive, send)
if access_token:
session = session_manager.get_session(access_token)
session_id = session_manager.find_session_id(access_token)
if session_id:
if session_manager.validate_session(session_id):
return await app(scope, receive, send)
if session:
response = JSONResponse(
{
'message': translator.get('invalid_access_token', __package__)
+20 -13
View File
@@ -3,16 +3,16 @@ from datetime import datetime, timedelta
from typing import Dict
from typing import Optional
from facefusion.types import Session, Token
from facefusion.types import Session, SessionId
SESSIONS : Dict[Token, Session] = {}
SESSIONS : Dict[SessionId, Session] = {}
def create_session() -> Session:
session : Session =\
{
'access_token': secrets.token_urlsafe(128),
'refresh_token': secrets.token_urlsafe(128),
'access_token': secrets.token_urlsafe(64),
'refresh_token': secrets.token_urlsafe(64),
'created_at': datetime.now(),
'expires_at': datetime.now() + timedelta(minutes = 10)
}
@@ -20,20 +20,27 @@ def create_session() -> Session:
return session
def get_session(access_token : Token) -> Optional[Session]:
return SESSIONS.get(access_token)
def get_session(session_id : SessionId) -> Optional[Session]:
return SESSIONS.get(session_id)
def set_session(access_token : Token, session : Session) -> None:
SESSIONS[access_token] = session
def find_session_id(access_token : str) -> Optional[SessionId]:
for session_id, session in SESSIONS.items():
if session.get('access_token') == access_token:
return session_id
return None
def validate_session(access_token : Token) -> bool:
session = get_session(access_token)
def set_session(session_id : SessionId, session : Session) -> None:
SESSIONS[session_id] = session
def validate_session(session_id : SessionId) -> bool:
session = get_session(session_id)
return session and datetime.now() <= session.get('expires_at')
def clear_session(access_token : Token) -> None:
if access_token in SESSIONS:
del SESSIONS[access_token]
def clear_session(session_id : SessionId) -> None:
if session_id in SESSIONS:
del SESSIONS[session_id]
+1
View File
@@ -103,6 +103,7 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool]
Content : TypeAlias = Dict[str, Any]
Token : TypeAlias = str
SessionId : TypeAlias = str
Session = TypedDict('Session',
{
'access_token': Token,
+10 -9
View File
@@ -28,7 +28,8 @@ def test_create_session(test_client : TestClient) -> None:
})
create_session_body = create_session_response.json()
assert session_manager.get_session(create_session_body.get('access_token'))
assert create_session_body.get('access_token')
assert create_session_body.get('refresh_token')
assert create_session_response.status_code == 201
create_session_response = test_client.post('/session', json =
@@ -78,9 +79,9 @@ def test_get_session(test_client : TestClient) -> None:
assert get_session_response.status_code == 200
access_token = create_session_body.get('access_token')
session : Session = session_manager.get_session(access_token)
session_manager.set_session(access_token,
session_id = session_manager.find_session_id(create_session_body.get('access_token'))
session : Session = session_manager.get_session(session_id)
session_manager.set_session(session_id,
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token'),
@@ -90,7 +91,7 @@ def test_get_session(test_client : TestClient) -> None:
get_session_response = test_client.get('/session', headers =
{
'Authorization': 'Bearer ' + access_token
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert get_session_response.status_code == 426
@@ -116,9 +117,9 @@ def test_refresh_session(test_client : TestClient) -> None:
})
refresh_session_body = refresh_session_response.json()
assert session_manager.get_session(create_session_body.get('access_token')) is None
assert session_manager.get_session(refresh_session_body.get('access_token'))
assert refresh_session_body.get('access_token')
assert refresh_session_body.get('refresh_token')
assert not refresh_session_body.get('access_token') == create_session_body.get('access_token')
assert refresh_session_response.status_code == 200
@@ -149,6 +150,6 @@ def test_destroy_session(test_client : TestClient) -> None:
'Authorization': 'Bearer ' + create_session_body.get('access_token')
})
assert session_manager.get_session(create_session_body.get('access_token')) is None
assert session_manager.find_session_id(create_session_body.get('access_token')) is None
assert delete_session_response.status_code == 200
+13 -13
View File
@@ -6,22 +6,22 @@ from facefusion.session_manager import clear_session, create_session, get_sessio
def test_get_and_set_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
session_id = secrets.token_urlsafe(16)
set_session(access_token, session)
set_session(session_id, session)
assert get_session(access_token) == session
assert get_session(session_id) == session
def test_validate_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
session_id = secrets.token_urlsafe(16)
set_session(access_token, session)
set_session(session_id, session)
assert validate_session(access_token) is True
assert validate_session(session_id) is True
set_session(access_token,
set_session(session_id,
{
'access_token': session.get('access_token'),
'refresh_token': session.get('refresh_token'),
@@ -29,17 +29,17 @@ def test_validate_session() -> None:
'expires_at': session.get('expires_at') - timedelta(hours = 1)
})
assert validate_session(access_token) is False
assert validate_session(session_id) is False
def test_clear_session() -> None:
session = create_session()
access_token = secrets.token_urlsafe(128)
session_id = secrets.token_urlsafe(16)
set_session(access_token, session)
set_session(session_id, session)
assert validate_session(access_token) is True
assert validate_session(session_id) is True
clear_session(access_token)
clear_session(session_id)
assert validate_session(access_token) is None
assert validate_session(session_id) is None