From 37005c03be5e48892b63bd127cfe2ef764ca4290 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Wed, 2 Apr 2025 16:03:51 +0200 Subject: [PATCH] fix: add tests (and found bug) --- gateway/common/authorization.py | 7 +- tests/unit_tests/common/test_authorization.py | 81 +++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/common/test_authorization.py diff --git a/gateway/common/authorization.py b/gateway/common/authorization.py index aa7fffb..74ba70b 100644 --- a/gateway/common/authorization.py +++ b/gateway/common/authorization.py @@ -40,10 +40,9 @@ def extract_authorization_from_headers( for header in llm_provider_fallback_api_key_headers: llm_provider_api_key = request.headers.get(header) if llm_provider_api_key: + llm_provider_api_key_header = header break - if "Bearer " in llm_provider_api_key: - llm_provider_api_key = llm_provider_api_key.split("Bearer ")[1].strip() if dataset_name: if invariant_authorization is None: @@ -66,4 +65,8 @@ def extract_authorization_from_headers( invariant_authorization = f"Bearer {api_keys[1].strip()}" llm_provider_api_key = f"{api_keys[0].strip()}" + + if llm_provider_api_key and "Bearer " in llm_provider_api_key: + llm_provider_api_key = llm_provider_api_key.split("Bearer ")[1].strip() + return invariant_authorization, llm_provider_api_key diff --git a/tests/unit_tests/common/test_authorization.py b/tests/unit_tests/common/test_authorization.py new file mode 100644 index 0000000..3742f35 --- /dev/null +++ b/tests/unit_tests/common/test_authorization.py @@ -0,0 +1,81 @@ +"""Tests for the authorization header extractor.""" + +import os +import sys +from fastapi import HTTPException +import random +import string +import pytest + + +# Add root folder (parent) to sys.path +sys.path.append( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) +) + +from gateway.common.authorization import extract_authorization_from_headers, INVARIANT_AUTHORIZATION_HEADER, API_KEYS_SEPARATOR + +@pytest.mark.parametrize("push_to_explorer", [True, False]) +@pytest.mark.parametrize("invariant_authorization", [True, False]) +@pytest.mark.parametrize("invariant_authorization_appended_to_llm_provider_api_key", [True, False]) +@pytest.mark.parametrize("use_fallback_header", [True, False]) +def test_extract_authorization_from_headers( + push_to_explorer: bool, + invariant_authorization: bool, + invariant_authorization_appended_to_llm_provider_api_key: bool, + use_fallback_header: bool, +): + """Test the extract_authorization_from_headers function.""" + + llm_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + inv_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + dataset_name = "test-dataset" if push_to_explorer else None + + headers: dict[str, str] = {} + + llm_provider_api_key = "fallback-header" if use_fallback_header else "llm-provider-api-key" + headers[llm_provider_api_key] = llm_apikey + + if invariant_authorization: + print("invariant_authorization - TRUE") + if invariant_authorization_appended_to_llm_provider_api_key: + headers[llm_provider_api_key] = f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}" + else: + headers[INVARIANT_AUTHORIZATION_HEADER] = f"Bearer {inv_apikey}" + + # Mock request headers + class MockRequest: + def __init__(self, headers): + self.headers = headers + + request = MockRequest(headers) + + # Call the function + try: + invariant_auth, llm_provider_api_key = extract_authorization_from_headers( + request, + dataset_name=dataset_name, + llm_provider_api_key_header="llm-provider-api-key", + llm_provider_fallback_api_key_headers=["fallback-header"], + ) + except HTTPException as e: + # If an exception is raised, check if it is the expected one + if not invariant_authorization: + assert e.status_code == 400 + assert e.detail == "Missing invariant api key" + return + else: + raise e + # Verify the results + if invariant_authorization: + if not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key: + assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[0] == llm_apikey + assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[1] == inv_apikey + else: + assert invariant_auth == ("Bearer " + inv_apikey) + else: + assert invariant_auth is None + if not(not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key): + assert llm_provider_api_key == llm_apikey \ No newline at end of file