mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Allow to specify different API keys for the guardrailing service (#36)
* minor refactor for getting invariant api keys for guardrailing * allow different guardrailing api key * tests * fix comment + import * improved unauthorized handling
This commit is contained in:
committed by
GitHub
parent
e17b53b927
commit
c4dd3f3b19
@@ -4,9 +4,22 @@ from typing import Tuple, Optional
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER = "invariant-guardrails-authorization"
|
||||
API_KEYS_SEPARATOR = ";invariant-auth="
|
||||
|
||||
|
||||
def extract_guardrail_service_authorization_from_headers(
    request: Request,
) -> Optional[str]:
    """
    Extracts the optional Invariant-Guardrails-Authorization header from the request.

    This header can be specified to use a different API key for guardrailing
    compared to Explorer interactions.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The header value exactly as sent by the client, or None if the header
        is not present. Note that this is a single value, not a tuple.
    """
    return request.headers.get(INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER)
|
||||
|
||||
|
||||
def extract_authorization_from_headers(
|
||||
request: Request,
|
||||
dataset_name: Optional[str] = None,
|
||||
|
||||
@@ -3,8 +3,13 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import fastapi
|
||||
|
||||
from common.config_manager import GatewayConfig
|
||||
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
from common.authorization import (
|
||||
extract_guardrail_service_authorization_from_headers,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -13,7 +18,10 @@ class RequestContext:
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
# authorization to use for invariant service like explorer
|
||||
invariant_authorization: Optional[str] = None
|
||||
# authorization to use for invariant guardrailing specifically
|
||||
guardrail_authorization: Optional[str] = None
|
||||
# the set of guardrails to enforce for this request
|
||||
guardrails: Optional[GuardrailRuleSet] = None
|
||||
config: Dict[str, Any] = None
|
||||
@@ -36,6 +44,7 @@ class RequestContext:
|
||||
invariant_authorization: Optional[str] = None,
|
||||
guardrails: Optional[GuardrailRuleSet] = None,
|
||||
config: Optional[GatewayConfig] = None,
|
||||
request: fastapi.Request = None,
|
||||
) -> "RequestContext":
|
||||
"""Creates a new RequestContext instance, applying default guardrails if needed."""
|
||||
|
||||
@@ -71,15 +80,36 @@ class RequestContext:
|
||||
logging_guardrails=[],
|
||||
)
|
||||
|
||||
# if additionally provided, extract separate API key to use with guardrailing service
|
||||
guardrail_service_authorization = None
|
||||
if (
|
||||
guardrail_authorization
|
||||
:= extract_guardrail_service_authorization_from_headers(request)
|
||||
):
|
||||
guardrail_service_authorization = guardrail_authorization
|
||||
|
||||
return cls(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrail_authorization=guardrail_service_authorization,
|
||||
guardrails=guardrails,
|
||||
config=context_config,
|
||||
_created_via_factory=True,
|
||||
)
|
||||
|
||||
def get_guardrailing_authorization(self) -> Optional[str]:
    """
    Resolve the authorization value to use for the guardrailing service.

    A dedicated guardrailing authorization takes precedence whenever it is
    set (truthy); otherwise the general invariant authorization is returned,
    which may itself be None.

    See also extract_guardrail_service_authorization_from_headers(...)
    """
    dedicated = self.guardrail_authorization
    if dedicated:
        return dedicated
    return self.invariant_authorization
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"RequestContext("
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
@@ -62,6 +64,11 @@ def create_annotations_from_guardrails_errors(
|
||||
return annotations
|
||||
|
||||
|
||||
def get_explorer_api_url() -> str:
    """
    Returns the base URL of the Invariant Explorer API.

    The URL is read from the INVARIANT_API_URL environment variable and falls
    back to DEFAULT_API_URL when the variable is not set. Callers are expected
    to strip any trailing slash themselves.
    """
    # NOTE: a hard-coded preview URL used to be returned here, which made the
    # environment lookup below unreachable and ignored the configured URL.
    return os.getenv("INVARIANT_API_URL", DEFAULT_API_URL)
|
||||
|
||||
|
||||
async def push_trace(
|
||||
messages: List[List[Dict[str, Any]]],
|
||||
dataset_name: str,
|
||||
@@ -94,7 +101,7 @@ async def push_trace(
|
||||
metadata=metadata,
|
||||
)
|
||||
client = AsyncClient(
|
||||
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
|
||||
api_url=get_explorer_api_url().rstrip("/"),
|
||||
api_key=invariant_authorization.split("Bearer ")[1],
|
||||
)
|
||||
try:
|
||||
@@ -117,7 +124,7 @@ async def fetch_guardrails_from_explorer(
|
||||
# dataset details without requiring a username.
|
||||
|
||||
client = httpx.AsyncClient(
|
||||
base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
|
||||
base_url=get_explorer_api_url().rstrip("/"),
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
},
|
||||
@@ -125,7 +132,12 @@ async def fetch_guardrails_from_explorer(
|
||||
|
||||
# Get the user details.
|
||||
user_info_response = await client.get("/api/v1/user/identity")
|
||||
if user_info_response.status_code != 200:
|
||||
if user_info_response.status_code == 401:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid Invariant API key. Please check your API key.",
|
||||
)
|
||||
elif user_info_response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
|
||||
)
|
||||
|
||||
@@ -6,9 +6,13 @@ import time
|
||||
from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException
|
||||
import httpx
|
||||
from common.guardrails import Guardrail
|
||||
from common.request_context import RequestContext
|
||||
from common.authorization import (
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
|
||||
)
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
@@ -96,12 +100,17 @@ async def preload_guardrails(context: "RequestContext") -> None:
|
||||
# Move these calls to a batch preload/validate API.
|
||||
for blocking_guardrail in context.guardrails.blocking_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(blocking_guardrail.content, context.invariant_authorization)
|
||||
_preload(
|
||||
blocking_guardrail.content, context.get_guardrailing_authorization()
|
||||
)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
for logging_guadrail in context.guardrails.logging_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(logging_guadrail.content, context.invariant_authorization)
|
||||
_preload(
|
||||
logging_guadrail.content,
|
||||
context.get_guardrailing_authorization(),
|
||||
)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
except Exception as e:
|
||||
@@ -332,7 +341,7 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]],
|
||||
guardrails: List[Guardrail],
|
||||
invariant_authorization: str,
|
||||
context: RequestContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks guardrails on the list of messages.
|
||||
@@ -357,11 +366,18 @@ async def check_guardrails(
|
||||
"policies": [g.content for g in guardrails],
|
||||
},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Authorization": context.get_guardrailing_authorization(),
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
if not result.is_success:
|
||||
if result.status_code == 401:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format(
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER
|
||||
),
|
||||
)
|
||||
raise Exception(
|
||||
f"Guardrails check failed: {result.status_code} - {result.text}"
|
||||
)
|
||||
@@ -392,6 +408,8 @@ async def check_guardrails(
|
||||
]
|
||||
}
|
||||
return aggregated_errors
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
# make sure runtime errors are also visible in e.g. Explorer
|
||||
|
||||
@@ -105,6 +105,7 @@ async def anthropic_v1_messages_gateway(
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
request=request,
|
||||
)
|
||||
if request_json.get("stream"):
|
||||
return await handle_streaming_response(context, client, anthropic_request)
|
||||
@@ -157,7 +158,7 @@ async def get_guardrails_check_result(
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_messages,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
context=context,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
@@ -64,11 +64,16 @@ async def gemini_generate_content_gateway(
|
||||
status_code=400,
|
||||
)
|
||||
headers = {
|
||||
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS + [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() not in IGNORED_HEADERS + [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
invariant_authorization, gemini_api_key = extract_authorization_from_headers(
|
||||
request, dataset_name, GEMINI_AUTHORIZATION_HEADER, [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
|
||||
request,
|
||||
dataset_name,
|
||||
GEMINI_AUTHORIZATION_HEADER,
|
||||
[GEMINI_AUTHORIZATION_FALLBACK_HEADER],
|
||||
)
|
||||
headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key
|
||||
|
||||
@@ -98,6 +103,7 @@ async def gemini_generate_content_gateway(
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
request=request,
|
||||
)
|
||||
if alt == "sse" or endpoint == "streamGenerateContent":
|
||||
return await stream_response(
|
||||
@@ -394,7 +400,7 @@ async def get_guardrails_check_result(
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_requests + converted_responses,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
context=context,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ async def openai_chat_completions_gateway(
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
request=request,
|
||||
)
|
||||
if request_json.get("stream", False):
|
||||
return await handle_stream_response(
|
||||
@@ -546,7 +547,7 @@ async def get_guardrails_check_result(
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
context=context,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ openai_client = OpenAI(
|
||||
"Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
"Invariant-Guardrails": guardrails,
|
||||
},
|
||||
base_url="http://localhost:9999/api/v1/gateway/non-streaming/openai",
|
||||
base_url="http://localhost:8005/api/v1/gateway/non-streaming/openai",
|
||||
)
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
|
||||
@@ -7,6 +7,9 @@ import random
|
||||
import string
|
||||
import pytest
|
||||
|
||||
from gateway.common.config_manager import GatewayConfig
|
||||
from gateway.common.request_context import RequestContext
|
||||
|
||||
|
||||
# Add root folder (parent) to sys.path
|
||||
sys.path.append(
|
||||
@@ -15,11 +18,19 @@ sys.path.append(
|
||||
)
|
||||
)
|
||||
|
||||
from gateway.common.authorization import extract_authorization_from_headers, INVARIANT_AUTHORIZATION_HEADER, API_KEYS_SEPARATOR
|
||||
from gateway.common.authorization import (
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
|
||||
extract_authorization_from_headers,
|
||||
INVARIANT_AUTHORIZATION_HEADER,
|
||||
API_KEYS_SEPARATOR,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("push_to_explorer", [True, False])
|
||||
@pytest.mark.parametrize("invariant_authorization", [True, False])
|
||||
@pytest.mark.parametrize("invariant_authorization_appended_to_llm_provider_api_key", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"invariant_authorization_appended_to_llm_provider_api_key", [True, False]
|
||||
)
|
||||
@pytest.mark.parametrize("use_fallback_header", [True, False])
|
||||
def test_extract_authorization_from_headers(
|
||||
push_to_explorer: bool,
|
||||
@@ -29,19 +40,23 @@ def test_extract_authorization_from_headers(
|
||||
):
|
||||
"""Test the extract_authorization_from_headers function."""
|
||||
|
||||
llm_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
inv_apikey = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
llm_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
inv_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
|
||||
dataset_name = "test-dataset" if push_to_explorer else None
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
llm_provider_api_key = "fallback-header" if use_fallback_header else "llm-provider-api-key"
|
||||
llm_provider_api_key = (
|
||||
"fallback-header" if use_fallback_header else "llm-provider-api-key"
|
||||
)
|
||||
headers[llm_provider_api_key] = llm_apikey
|
||||
|
||||
if invariant_authorization:
|
||||
print("invariant_authorization - TRUE")
|
||||
if invariant_authorization_appended_to_llm_provider_api_key:
|
||||
headers[llm_provider_api_key] = f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}"
|
||||
headers[llm_provider_api_key] = (
|
||||
f"{headers.get(llm_provider_api_key, '')}{API_KEYS_SEPARATOR}{inv_apikey}"
|
||||
)
|
||||
else:
|
||||
headers[INVARIANT_AUTHORIZATION_HEADER] = f"Bearer {inv_apikey}"
|
||||
|
||||
@@ -50,7 +65,7 @@ def test_extract_authorization_from_headers(
|
||||
def __init__(self, headers):
|
||||
self.headers = headers
|
||||
|
||||
request = MockRequest(headers)
|
||||
request = MockRequest(headers)
|
||||
|
||||
# Call the function
|
||||
try:
|
||||
@@ -70,12 +85,79 @@ def test_extract_authorization_from_headers(
|
||||
raise e
|
||||
# Verify the results
|
||||
if invariant_authorization:
|
||||
if not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key:
|
||||
if (
|
||||
not push_to_explorer
|
||||
and invariant_authorization_appended_to_llm_provider_api_key
|
||||
):
|
||||
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[0] == llm_apikey
|
||||
assert llm_provider_api_key.split(API_KEYS_SEPARATOR)[1] == inv_apikey
|
||||
else:
|
||||
assert invariant_auth == ("Bearer " + inv_apikey)
|
||||
else:
|
||||
assert invariant_auth is None
|
||||
if not(not push_to_explorer and invariant_authorization_appended_to_llm_provider_api_key):
|
||||
assert llm_provider_api_key == llm_apikey
|
||||
if not (
|
||||
not push_to_explorer
|
||||
and invariant_authorization_appended_to_llm_provider_api_key
|
||||
):
|
||||
assert llm_provider_api_key == llm_apikey
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_guardrailing_api_key", [True, False])
def test_extract_guardrails_authorization_from_headers(use_guardrailing_api_key: bool):
    """Test that a dedicated guardrailing API key is picked up from the headers.

    When the guardrails authorization header is present, the request context
    must use it for guardrailing; otherwise it must fall back to the regular
    Invariant authorization.
    """
    headers: dict[str, str] = {}

    inv_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))
    inv_guardrails_apikey = "".join(
        random.choices(string.ascii_letters + string.digits, k=10)
    )
    llm_apikey = "".join(random.choices(string.ascii_letters + string.digits, k=10))

    headers[INVARIANT_AUTHORIZATION_HEADER] = f"Bearer {inv_apikey}"
    headers["Authorization"] = f"Bearer {llm_apikey}"

    if use_guardrailing_api_key:
        headers[INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER] = (
            f"Bearer {inv_guardrails_apikey}"
        )

    class MockRequest:
        """Minimal stand-in for a request object: only exposes `headers`."""

        def __init__(self, headers):
            self.headers = headers

    dataset_name = "test-dataset"
    request = MockRequest(headers)

    # Both the Invariant and the LLM provider authorization headers are always
    # set above, so no HTTPException is expected here. Any exception is a real
    # failure and is deliberately left to propagate (the previous handler read
    # `invariant_authorization` before it was guaranteed to be bound).
    invariant_authorization, llm_provider_api_key = (
        extract_authorization_from_headers(
            request,
            dataset_name=dataset_name,
            llm_provider_api_key_header="Authorization",
        )
    )

    context = RequestContext.create(
        request_json={"input": "test"},
        dataset_name=dataset_name,
        invariant_authorization=invariant_authorization,
        guardrails=None,
        config=GatewayConfig(),
        request=request,
    )

    # Verify the results
    assert invariant_authorization == ("Bearer " + inv_apikey)
    assert llm_provider_api_key == llm_apikey
    assert context.get_guardrailing_authorization() == (
        "Bearer " + inv_guardrails_apikey
        if use_guardrailing_api_key
        else invariant_authorization
    )
|
||||
|
||||
Reference in New Issue
Block a user