mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 07:19:42 +02:00
fix: add tests (and found bug)
This commit is contained in:
@@ -40,10 +40,9 @@ def extract_authorization_from_headers(
|
||||
for header in llm_provider_fallback_api_key_headers:
|
||||
llm_provider_api_key = request.headers.get(header)
|
||||
if llm_provider_api_key:
|
||||
llm_provider_api_key_header = header
|
||||
break
|
||||
|
||||
if "Bearer " in llm_provider_api_key:
|
||||
llm_provider_api_key = llm_provider_api_key.split("Bearer ")[1].strip()
|
||||
|
||||
if dataset_name:
|
||||
if invariant_authorization is None:
|
||||
@@ -66,4 +65,8 @@ def extract_authorization_from_headers(
|
||||
|
||||
invariant_authorization = f"Bearer {api_keys[1].strip()}"
|
||||
llm_provider_api_key = f"{api_keys[0].strip()}"
|
||||
|
||||
if llm_provider_api_key and "Bearer " in llm_provider_api_key:
|
||||
llm_provider_api_key = llm_provider_api_key.split("Bearer ")[1].strip()
|
||||
|
||||
return invariant_authorization, llm_provider_api_key
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Tests for the authorization header extractor."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from fastapi import HTTPException
|
||||
import random
|
||||
import string
|
||||
import pytest
|
||||
|
||||
|
||||
# Add root folder (parent) to sys.path
|
||||
sys.path.append(
|
||||
os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
)
|
||||
|
||||
from gateway.common.authorization import extract_authorization_from_headers, INVARIANT_AUTHORIZATION_HEADER, API_KEYS_SEPARATOR
|
||||
|
||||
@pytest.mark.parametrize("push_to_explorer", [True, False])
|
||||
@pytest.mark.parametrize("invariant_authorization", [True, False])
|
||||
@pytest.mark.parametrize("invariant_authorization_appended_to_llm_provider_api_key", [True, False])
|
||||
@pytest.mark.parametrize("use_fallback_header", [True, False])
|
||||
def test_extract_authorization_from_headers(
|
||||
push_to_explorer: bool,
|
||||
invariant_authorization: bool,
|
||||
invariant_authorization_appended_to_llm_provider_api_key: bool,
|
||||
use_fallback_header: bool,
|
||||
):
|
||||
"""Test the extract_authorization_from_headers function."""
|
||||
|
||||
llm_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
inv_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
dataset_name = "test-dataset" if push_to_explorer else None
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
llm_provider_api_key = "fallback-header" if use_fallback_header else "llm-provider-api-key"
|
||||
headers[llm_provider_api_key] = llm_apikey
|
||||
|
||||
if invariant_authorization:
|
||||
print("invariant_authorization - TRUE")
|
||||
if invariant_authorization_appended_to_llm_provider_api_key:
|
||||
headers[llm_provider_api_key] = f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}"
|
||||
else:
|
||||
headers[INVARIANT_AUTHORIZATION_HEADER] = f"Bearer {inv_apikey}"
|
||||
|
||||
# Mock request headers
|
||||
class MockRequest:
|
||||
def __init__(self, headers):
|
||||
self.headers = headers
|
||||
|
||||
request = MockRequest(headers)
|
||||
|
||||
# Call the function
|
||||
try:
|
||||
invariant_auth, llm_provider_api_key = extract_authorization_from_headers(
|
||||
request,
|
||||
dataset_name=dataset_name,
|
||||
llm_provider_api_key_header="llm-provider-api-key",
|
||||
llm_provider_fallback_api_key_headers=["fallback-header"],
|
||||
)
|
||||
except HTTPException as e:
|
||||
# If an exception is raised, check if it is the expected one
|
||||
if not invariant_authorization:
|
||||
assert e.status_code == 400
|
||||
assert e.detail == "Missing invariant api key"
|
||||
return
|
||||
else:
|
||||
raise e
|
||||
# Verify the results
|
||||
if invariant_authorization:
|
||||
if not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key:
|
||||
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[0] == llm_apikey
|
||||
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[1] == inv_apikey
|
||||
else:
|
||||
assert invariant_auth == ("Bearer " + inv_apikey)
|
||||
else:
|
||||
assert invariant_auth is None
|
||||
if not(not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key):
|
||||
assert llm_provider_api_key == llm_apikey
|
||||
Reference in New Issue
Block a user