Allow specifying different API keys for the guardrailing service (#36)

* minor refactor for getting invariant api keys for guardrailing

* allow different guardrailing api key

* tests

* fix comment + import

* improved unauthorized handling
This commit is contained in:
Luca Beurer-Kellner
2025-04-03 12:15:30 +02:00
committed by GitHub
parent e17b53b927
commit c4dd3f3b19
9 changed files with 187 additions and 24 deletions

View File

@@ -4,9 +4,22 @@ from typing import Tuple, Optional
from fastapi import HTTPException, Request
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER = "invariant-guardrails-authorization"
API_KEYS_SEPARATOR = ";invariant-auth="
def extract_guardrail_service_authorization_from_headers(
    request: Optional[Request],
) -> Optional[str]:
    """
    Extracts the optional Invariant-Guardrails-Authorization header from the request.

    This header can be specified to use a different API key for guardrailing
    compared to Explorer interactions.

    Args:
        request: The incoming request, or None when no request is available.

    Returns:
        The header value exactly as sent by the client, or None if there is
        no request or the header is not present.
    """
    if request is None:
        # Without a request there is no header to read, hence no
        # guardrailing-specific key.
        return None
    return request.headers.get(INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER)
def extract_authorization_from_headers(
request: Request,
dataset_name: Optional[str] = None,

View File

@@ -3,8 +3,13 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import fastapi
from common.config_manager import GatewayConfig
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from common.authorization import (
extract_guardrail_service_authorization_from_headers,
)
@dataclass(frozen=True)
@@ -13,7 +18,10 @@ class RequestContext:
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
# authorization to use for invariant service like explorer
invariant_authorization: Optional[str] = None
# authorization to use for invariant guardrailing specifically
guardrail_authorization: Optional[str] = None
# the set of guardrails to enforce for this request
guardrails: Optional[GuardrailRuleSet] = None
config: Dict[str, Any] = None
@@ -36,6 +44,7 @@ class RequestContext:
invariant_authorization: Optional[str] = None,
guardrails: Optional[GuardrailRuleSet] = None,
config: Optional[GatewayConfig] = None,
request: fastapi.Request = None,
) -> "RequestContext":
"""Creates a new RequestContext instance, applying default guardrails if needed."""
@@ -71,15 +80,36 @@ class RequestContext:
logging_guardrails=[],
)
# if additionally provided, extract separate API key to use with guardrailing service
guardrail_service_authorization = None
if (
guardrail_authorization
:= extract_guardrail_service_authorization_from_headers(request)
):
guardrail_service_authorization = guardrail_authorization
return cls(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
guardrail_authorization=guardrail_service_authorization,
guardrails=guardrails,
config=context_config,
_created_via_factory=True,
)
def get_guardrailing_authorization(self) -> Optional[str]:
    """
    Resolves the authorization value to send to the guardrailing service.

    A dedicated guardrail authorization takes precedence when one is set;
    otherwise the general invariant authorization is used instead.

    See also extract_guardrail_service_authorization_from_headers(...)
    """
    dedicated = self.guardrail_authorization
    if dedicated:
        return dedicated
    # nothing guardrail-specific was provided: reuse the general key
    return self.invariant_authorization
def __repr__(self) -> str:
return (
f"RequestContext("

View File

@@ -3,6 +3,8 @@
import os
from typing import Any, Dict, List
from fastapi import HTTPException
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
@@ -62,6 +64,11 @@ def create_annotations_from_guardrails_errors(
return annotations
def get_explorer_api_url() -> str:
    """
    Returns the base URL of the Explorer API.

    The INVARIANT_API_URL environment variable takes precedence; when it is
    unset or empty, DEFAULT_API_URL is used.

    Returns:
        The configured Explorer URL, returned unmodified (a trailing slash,
        if any, is left in place).
    """
    # Resolve from the environment on every call instead of pinning the
    # gateway to a single hard-coded deployment.
    return os.environ.get("INVARIANT_API_URL") or DEFAULT_API_URL
async def push_trace(
messages: List[List[Dict[str, Any]]],
dataset_name: str,
@@ -94,7 +101,7 @@ async def push_trace(
metadata=metadata,
)
client = AsyncClient(
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
api_url=get_explorer_api_url().rstrip("/"),
api_key=invariant_authorization.split("Bearer ")[1],
)
try:
@@ -117,7 +124,7 @@ async def fetch_guardrails_from_explorer(
# dataset details without requiring a username.
client = httpx.AsyncClient(
base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
base_url=get_explorer_api_url().rstrip("/"),
headers={
"Authorization": invariant_authorization,
},
@@ -125,7 +132,12 @@ async def fetch_guardrails_from_explorer(
# Get the user details.
user_info_response = await client.get("/api/v1/user/identity")
if user_info_response.status_code != 200:
if user_info_response.status_code == 401:
raise HTTPException(
status_code=401,
detail="Invalid Invariant API key. Please check your API key.",
)
elif user_info_response.status_code != 200:
raise ValueError(
f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
)

View File

@@ -6,9 +6,13 @@ import time
from typing import Any, Dict, List
from functools import wraps
from fastapi import HTTPException
import httpx
from common.guardrails import Guardrail
from common.request_context import RequestContext
from common.authorization import (
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
)
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -96,12 +100,17 @@ async def preload_guardrails(context: "RequestContext") -> None:
# Move these calls to a batch preload/validate API.
for blocking_guardrail in context.guardrails.blocking_guardrails:
task = asyncio.create_task(
_preload(blocking_guardrail.content, context.invariant_authorization)
_preload(
blocking_guardrail.content, context.get_guardrailing_authorization()
)
)
asyncio.shield(task)
for logging_guadrail in context.guardrails.logging_guardrails:
task = asyncio.create_task(
_preload(logging_guadrail.content, context.invariant_authorization)
_preload(
logging_guadrail.content,
context.get_guardrailing_authorization(),
)
)
asyncio.shield(task)
except Exception as e:
@@ -332,7 +341,7 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
async def check_guardrails(
messages: List[Dict[str, Any]],
guardrails: List[Guardrail],
invariant_authorization: str,
context: RequestContext,
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
@@ -357,11 +366,18 @@ async def check_guardrails(
"policies": [g.content for g in guardrails],
},
headers={
"Authorization": invariant_authorization,
"Authorization": context.get_guardrailing_authorization(),
"Accept": "application/json",
},
)
if not result.is_success:
if result.status_code == 401:
raise HTTPException(
status_code=401,
detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format(
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER
),
)
raise Exception(
f"Guardrails check failed: {result.status_code} - {result.text}"
)
@@ -392,6 +408,8 @@ async def check_guardrails(
]
}
return aggregated_errors
except HTTPException as e:
raise e
except Exception as e:
print(f"Failed to verify guardrails: {e}")
# make sure runtime errors are also visible in e.g. Explorer

View File

@@ -105,6 +105,7 @@ async def anthropic_v1_messages_gateway(
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
request=request,
)
if request_json.get("stream"):
return await handle_streaming_response(context, client, anthropic_request)
@@ -157,7 +158,7 @@ async def get_guardrails_check_result(
guardrails_execution_result = await check_guardrails(
messages=converted_messages,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
context=context,
)
return guardrails_execution_result

View File

@@ -64,11 +64,16 @@ async def gemini_generate_content_gateway(
status_code=400,
)
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS + [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
k: v
for k, v in request.headers.items()
if k.lower() not in IGNORED_HEADERS + [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
}
headers["accept-encoding"] = "identity"
invariant_authorization, gemini_api_key = extract_authorization_from_headers(
request, dataset_name, GEMINI_AUTHORIZATION_HEADER, [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
request,
dataset_name,
GEMINI_AUTHORIZATION_HEADER,
[GEMINI_AUTHORIZATION_FALLBACK_HEADER],
)
headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key
@@ -98,6 +103,7 @@ async def gemini_generate_content_gateway(
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
request=request,
)
if alt == "sse" or endpoint == "streamGenerateContent":
return await stream_response(
@@ -394,7 +400,7 @@ async def get_guardrails_check_result(
guardrails_execution_result = await check_guardrails(
messages=converted_requests + converted_responses,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
context=context,
)
return guardrails_execution_result

View File

@@ -149,6 +149,7 @@ async def openai_chat_completions_gateway(
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
request=request,
)
if request_json.get("stream", False):
return await handle_stream_response(
@@ -546,7 +547,7 @@ async def get_guardrails_check_result(
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
context=context,
)
return guardrails_execution_result

View File

@@ -18,7 +18,7 @@ openai_client = OpenAI(
"Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"),
"Invariant-Guardrails": guardrails,
},
base_url="http://localhost:9999/api/v1/gateway/non-streaming/openai",
base_url="http://localhost:8005/api/v1/gateway/non-streaming/openai",
)
response = openai_client.chat.completions.create(

View File

@@ -7,6 +7,9 @@ import random
import string
import pytest
from gateway.common.config_manager import GatewayConfig
from gateway.common.request_context import RequestContext
# Add root folder (parent) to sys.path
sys.path.append(
@@ -15,11 +18,19 @@ sys.path.append(
)
)
from gateway.common.authorization import extract_authorization_from_headers, INVARIANT_AUTHORIZATION_HEADER, API_KEYS_SEPARATOR
from gateway.common.authorization import (
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
extract_authorization_from_headers,
INVARIANT_AUTHORIZATION_HEADER,
API_KEYS_SEPARATOR,
)
@pytest.mark.parametrize("push_to_explorer", [True, False])
@pytest.mark.parametrize("invariant_authorization", [True, False])
@pytest.mark.parametrize("invariant_authorization_appended_to_llm_provider_api_key", [True, False])
@pytest.mark.parametrize(
"invariant_authorization_appended_to_llm_provider_api_key", [True, False]
)
@pytest.mark.parametrize("use_fallback_header", [True, False])
def test_extract_authorization_from_headers(
push_to_explorer: bool,
@@ -29,19 +40,23 @@ def test_extract_authorization_from_headers(
):
"""Test the extract_authorization_from_headers function."""
llm_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
inv_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
llm_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
inv_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
dataset_name = "test-dataset" if push_to_explorer else None
headers: dict[str, str] = {}
headers: dict[str, str] = {}
llm_provider_api_key = "fallback-header" if use_fallback_header else "llm-provider-api-key"
llm_provider_api_key = (
"fallback-header" if use_fallback_header else "llm-provider-api-key"
)
headers[llm_provider_api_key] = llm_apikey
if invariant_authorization:
print("invariant_authorization - TRUE")
if invariant_authorization_appended_to_llm_provider_api_key:
headers[llm_provider_api_key] = f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}"
headers[llm_provider_api_key] = (
f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}"
)
else:
headers[INVARIANT_AUTHORIZATION_HEADER] = f"Bearer {inv_apikey}"
@@ -50,7 +65,7 @@ def test_extract_authorization_from_headers(
def __init__(self, headers):
self.headers = headers
request = MockRequest(headers)
request = MockRequest(headers)
# Call the function
try:
@@ -70,12 +85,79 @@ def test_extract_authorization_from_headers(
raise e
# Verify the results
if invariant_authorization:
if not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key:
if (
not push_to_explorer
and invariant_authorization_appended_to_llm_provider_api_key
):
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[0] == llm_apikey
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[1] == inv_apikey
else:
assert invariant_auth == ("Bearer " + inv_apikey)
else:
assert invariant_auth is None
if not(not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key):
assert llm_provider_api_key == llm_apikey
if not (
not push_to_explorer
and invariant_authorization_appended_to_llm_provider_api_key
):
assert llm_provider_api_key == llm_apikey
@pytest.mark.parametrize("use_guardrailing_api_key", [True, False])
def test_extract_guardrails_authorization_from_headers(use_guardrailing_api_key: bool):
    """
    Test which API key a RequestContext uses for the guardrailing service.

    With the dedicated guardrailing header present, its value must be used;
    without it, the context must fall back to the general Invariant
    authorization.
    """
    inv_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
    inv_guardrails_apikey = "".join(
        random.choices(string.ascii_letters + string.digits, k=10)
    )
    llm_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))

    headers: dict[str, str] = {
        INVARIANT_AUTHORIZATION_HEADER: f"Bearer {inv_apikey}",
        "Authorization": f"Bearer {llm_apikey}",
    }
    if use_guardrailing_api_key:
        headers[INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER] = (
            f"Bearer {inv_guardrails_apikey}"
        )

    class MockRequest:
        """Minimal stand-in exposing only the `headers` attribute."""

        def __init__(self, headers):
            self.headers = headers

    dataset_name = "test-dataset"
    request = MockRequest(headers)

    # Both the Invariant and the LLM provider keys are always supplied above,
    # so neither call below is expected to raise. Any exception is left to
    # propagate and fail the test instead of being inspected in a handler
    # that would read variables that are not bound yet.
    invariant_authorization, llm_provider_api_key = (
        extract_authorization_from_headers(
            request,
            dataset_name=dataset_name,
            llm_provider_api_key_header="Authorization",
        )
    )
    context = RequestContext.create(
        request_json={"input": "test"},
        dataset_name=dataset_name,
        invariant_authorization=invariant_authorization,
        guardrails=None,
        config=GatewayConfig(),
        request=request,
    )

    # Verify the results
    assert invariant_authorization == ("Bearer " + inv_apikey)
    assert llm_provider_api_key == llm_apikey

    expected_guardrailing_authorization = (
        "Bearer " + inv_guardrails_apikey
        if use_guardrailing_api_key
        else invariant_authorization
    )
    assert context.get_guardrailing_authorization() == (
        expected_guardrailing_authorization
    )