mirror of
https://github.com/facefusion/facefusion.git
synced 2026-04-22 17:36:16 +02:00
Add session_id, make token size more reasonable (#983)
* Add session_id, make token size more reasonable * Use more direct approach
This commit is contained in:
+21
-17
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from starlette.datastructures import Headers
|
||||
@@ -15,8 +16,9 @@ async def create_session(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
|
||||
if not body.get('api_key') or body.get('api_key') == os.getenv('FACEFUSION_API_KEY'):
|
||||
session_id = secrets.token_urlsafe(16)
|
||||
session = session_manager.create_session()
|
||||
session_manager.set_session(session.get('access_token'), session)
|
||||
session_manager.set_session(session_id, session)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
@@ -34,9 +36,11 @@ async def get_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if session_id:
|
||||
session = session_manager.get_session(session_id)
|
||||
|
||||
if session:
|
||||
return JSONResponse(
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
@@ -54,16 +58,15 @@ async def get_session(request : Request) -> JSONResponse:
|
||||
async def refresh_session(request : Request) -> JSONResponse:
|
||||
body = await request.json()
|
||||
|
||||
for access_token, session in session_manager.SESSIONS.items():
|
||||
for session_id, session in session_manager.SESSIONS.items():
|
||||
if session.get('refresh_token') == body.get('refresh_token'):
|
||||
session_manager.clear_session(access_token)
|
||||
session = session_manager.create_session()
|
||||
session_manager.set_session(session.get('access_token'), session)
|
||||
__session__ = session_manager.create_session()
|
||||
session_manager.set_session(session_id, __session__)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token')
|
||||
'access_token': __session__.get('access_token'),
|
||||
'refresh_token': __session__.get('refresh_token')
|
||||
}, status_code = HTTP_200_OK)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -76,10 +79,11 @@ async def destroy_session(request : Request) -> JSONResponse:
|
||||
access_token = extract_access_token(request.headers)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if session_id:
|
||||
session_manager.clear_session(session_id)
|
||||
|
||||
if session:
|
||||
session_manager.clear_session(access_token)
|
||||
return JSONResponse(
|
||||
{
|
||||
'message': translator.get('ok', __package__)
|
||||
@@ -95,13 +99,13 @@ def create_session_guard(app : ASGIApp) -> ASGIApp:
|
||||
async def middleware(scope : Scope, receive : Receive, send : Send) -> None:
|
||||
access_token = extract_access_token(Headers(scope = scope))
|
||||
|
||||
if access_token and session_manager.validate_session(access_token):
|
||||
return await app(scope, receive, send)
|
||||
|
||||
if access_token:
|
||||
session = session_manager.get_session(access_token)
|
||||
session_id = session_manager.find_session_id(access_token)
|
||||
|
||||
if session_id:
|
||||
if session_manager.validate_session(session_id):
|
||||
return await app(scope, receive, send)
|
||||
|
||||
if session:
|
||||
response = JSONResponse(
|
||||
{
|
||||
'message': translator.get('invalid_access_token', __package__)
|
||||
|
||||
@@ -3,16 +3,16 @@ from datetime import datetime, timedelta
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from facefusion.types import Session, Token
|
||||
from facefusion.types import Session, SessionId
|
||||
|
||||
SESSIONS : Dict[Token, Session] = {}
|
||||
SESSIONS : Dict[SessionId, Session] = {}
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
session : Session =\
|
||||
{
|
||||
'access_token': secrets.token_urlsafe(128),
|
||||
'refresh_token': secrets.token_urlsafe(128),
|
||||
'access_token': secrets.token_urlsafe(64),
|
||||
'refresh_token': secrets.token_urlsafe(64),
|
||||
'created_at': datetime.now(),
|
||||
'expires_at': datetime.now() + timedelta(minutes = 10)
|
||||
}
|
||||
@@ -20,20 +20,27 @@ def create_session() -> Session:
|
||||
return session
|
||||
|
||||
|
||||
def get_session(access_token : Token) -> Optional[Session]:
|
||||
return SESSIONS.get(access_token)
|
||||
def get_session(session_id : SessionId) -> Optional[Session]:
|
||||
return SESSIONS.get(session_id)
|
||||
|
||||
|
||||
def set_session(access_token : Token, session : Session) -> None:
|
||||
SESSIONS[access_token] = session
|
||||
def find_session_id(access_token : str) -> Optional[SessionId]:
|
||||
for session_id, session in SESSIONS.items():
|
||||
if session.get('access_token') == access_token:
|
||||
return session_id
|
||||
return None
|
||||
|
||||
|
||||
def validate_session(access_token : Token) -> bool:
|
||||
session = get_session(access_token)
|
||||
def set_session(session_id : SessionId, session : Session) -> None:
|
||||
SESSIONS[session_id] = session
|
||||
|
||||
|
||||
def validate_session(session_id : SessionId) -> bool:
|
||||
session = get_session(session_id)
|
||||
return session and datetime.now() <= session.get('expires_at')
|
||||
|
||||
|
||||
def clear_session(access_token : Token) -> None:
|
||||
if access_token in SESSIONS:
|
||||
del SESSIONS[access_token]
|
||||
def clear_session(session_id : SessionId) -> None:
|
||||
if session_id in SESSIONS:
|
||||
del SESSIONS[session_id]
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ ProcessStep : TypeAlias = Callable[[str, int, Args], bool]
|
||||
Content : TypeAlias = Dict[str, Any]
|
||||
|
||||
Token : TypeAlias = str
|
||||
SessionId : TypeAlias = str
|
||||
Session = TypedDict('Session',
|
||||
{
|
||||
'access_token': Token,
|
||||
|
||||
@@ -28,7 +28,8 @@ def test_create_session(test_client : TestClient) -> None:
|
||||
})
|
||||
create_session_body = create_session_response.json()
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token'))
|
||||
assert create_session_body.get('access_token')
|
||||
assert create_session_body.get('refresh_token')
|
||||
assert create_session_response.status_code == 201
|
||||
|
||||
create_session_response = test_client.post('/session', json =
|
||||
@@ -78,9 +79,9 @@ def test_get_session(test_client : TestClient) -> None:
|
||||
|
||||
assert get_session_response.status_code == 200
|
||||
|
||||
access_token = create_session_body.get('access_token')
|
||||
session : Session = session_manager.get_session(access_token)
|
||||
session_manager.set_session(access_token,
|
||||
session_id = session_manager.find_session_id(create_session_body.get('access_token'))
|
||||
session : Session = session_manager.get_session(session_id)
|
||||
session_manager.set_session(session_id,
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token'),
|
||||
@@ -90,7 +91,7 @@ def test_get_session(test_client : TestClient) -> None:
|
||||
|
||||
get_session_response = test_client.get('/session', headers =
|
||||
{
|
||||
'Authorization': 'Bearer ' + access_token
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert get_session_response.status_code == 426
|
||||
@@ -116,9 +117,9 @@ def test_refresh_session(test_client : TestClient) -> None:
|
||||
})
|
||||
refresh_session_body = refresh_session_response.json()
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token')) is None
|
||||
|
||||
assert session_manager.get_session(refresh_session_body.get('access_token'))
|
||||
assert refresh_session_body.get('access_token')
|
||||
assert refresh_session_body.get('refresh_token')
|
||||
assert not refresh_session_body.get('access_token') == create_session_body.get('access_token')
|
||||
|
||||
assert refresh_session_response.status_code == 200
|
||||
|
||||
@@ -149,6 +150,6 @@ def test_destroy_session(test_client : TestClient) -> None:
|
||||
'Authorization': 'Bearer ' + create_session_body.get('access_token')
|
||||
})
|
||||
|
||||
assert session_manager.get_session(create_session_body.get('access_token')) is None
|
||||
assert session_manager.find_session_id(create_session_body.get('access_token')) is None
|
||||
|
||||
assert delete_session_response.status_code == 200
|
||||
|
||||
@@ -6,22 +6,22 @@ from facefusion.session_manager import clear_session, create_session, get_sessio
|
||||
|
||||
def test_get_and_set_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
session_id = secrets.token_urlsafe(16)
|
||||
|
||||
set_session(access_token, session)
|
||||
set_session(session_id, session)
|
||||
|
||||
assert get_session(access_token) == session
|
||||
assert get_session(session_id) == session
|
||||
|
||||
|
||||
def test_validate_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
session_id = secrets.token_urlsafe(16)
|
||||
|
||||
set_session(access_token, session)
|
||||
set_session(session_id, session)
|
||||
|
||||
assert validate_session(access_token) is True
|
||||
assert validate_session(session_id) is True
|
||||
|
||||
set_session(access_token,
|
||||
set_session(session_id,
|
||||
{
|
||||
'access_token': session.get('access_token'),
|
||||
'refresh_token': session.get('refresh_token'),
|
||||
@@ -29,17 +29,17 @@ def test_validate_session() -> None:
|
||||
'expires_at': session.get('expires_at') - timedelta(hours = 1)
|
||||
})
|
||||
|
||||
assert validate_session(access_token) is False
|
||||
assert validate_session(session_id) is False
|
||||
|
||||
|
||||
def test_clear_session() -> None:
|
||||
session = create_session()
|
||||
access_token = secrets.token_urlsafe(128)
|
||||
session_id = secrets.token_urlsafe(16)
|
||||
|
||||
set_session(access_token, session)
|
||||
set_session(session_id, session)
|
||||
|
||||
assert validate_session(access_token) is True
|
||||
assert validate_session(session_id) is True
|
||||
|
||||
clear_session(access_token)
|
||||
clear_session(session_id)
|
||||
|
||||
assert validate_session(access_token) is None
|
||||
assert validate_session(session_id) is None
|
||||
|
||||
Reference in New Issue
Block a user