Merge branch 'main' into guardrails-from-header

This commit is contained in:
Luca Beurer-Kellner
2025-04-02 17:28:05 +02:00
27 changed files with 1575 additions and 414 deletions
+29
View File
@@ -0,0 +1,29 @@
from openai import OpenAI
from httpx import Client
import os
# unicode escape everything
guardrails = """
raise "Rule 1: Do not talk about Fight Club" if:
(msg: Message)
"fight club" in msg.content
""".encode("unicode_escape")
openai_client = OpenAI(
default_headers={
"Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"),
"Invariant-Guardrails": guardrails,
},
base_url="http://localhost:8000/api/v1/gateway/non-streaming/openai",
)
response = openai_client.chat.completions.create(
model="gpt-4",
messages=[
{
"role": "user",
"content": "What can you tell me about fight club?",
}
],
)
print("Response: ", response.choices[0].message.content)
+3 -3
View File
@@ -1,4 +1,4 @@
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_DB=invariantmonitor
POSTGRES_HOST=database
POSTGRES_PASSWORD=postgres
POSTGRES_DB=invariantmonitor
POSTGRES_HOST=database
+13 -6
View File
@@ -8,8 +8,10 @@ API_KEYS_SEPARATOR = ";invariant-auth="
def extract_authorization_from_headers(
request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str
) -> Tuple[str, str]:
request: Request,
dataset_name: Optional[str] = None,
llm_provider_api_key_header: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
Extracts the Invariant authorization and LLM Provider API key from the request headers.
@@ -26,8 +28,15 @@ def extract_authorization_from_headers(
The header in that case becomes:
{llm_provider_api_key_header}: "<API Key>;invariant-auth=<Invariant API Key>"
"""
# invariant api key
invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER)
llm_provider_api_key = request.headers.get(llm_provider_api_key_header)
# llm provider api key
if llm_provider_api_key_header is not None:
llm_provider_api_key = request.headers.get(llm_provider_api_key_header)
else:
llm_provider_api_key = None
# if the dataset name is not None, we need to check if the invariant api key is present
if dataset_name:
if invariant_authorization is None:
if llm_provider_api_key is None:
@@ -43,9 +52,7 @@ def extract_authorization_from_headers(
API_KEYS_SEPARATOR
)
if len(api_keys) != 2 or not api_keys[1].strip():
raise HTTPException(
status_code=400, detail="Invalid API Key format"
)
raise HTTPException(status_code=400, detail="Invalid API Key format")
invariant_authorization = f"Bearer {api_keys[1].strip()}"
llm_provider_api_key = f"{api_keys[0].strip()}"
+22 -13
View File
@@ -8,6 +8,9 @@ from typing import Optional
import fastapi
from httpx import HTTPStatusError
from common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet
from common.authorization import extract_authorization_from_headers
def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]:
"""
@@ -29,8 +32,8 @@ def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[
class GatewayConfig:
"""Common configurations for the Gateway Server."""
def __init__(self, guardrails: Optional[str] = None):
self.guardrails = guardrails or self._load_guardrails_from_file()
def __init__(self):
self.guardrails = self._load_guardrails_from_file()
def _load_guardrails_from_file(self) -> str:
"""
@@ -67,13 +70,7 @@ class GatewayConfig:
raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e
def __repr__(self) -> str:
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
def with_guardrails(self, guardrails: str) -> "GatewayConfig":
"""
Returns a new GatewayConfig instance with the specified guardrails.
"""
return GatewayConfig(guardrails)
return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})"
class GatewayConfigManager:
@@ -94,8 +91,20 @@ class GatewayConfigManager:
local_config = GatewayConfig()
cls._config_instance = local_config
# if provided in header, use custom guardrailing policy
if guardrail_file_contents := extract_policy_from_headers(request):
local_config = local_config.with_guardrails(guardrail_file_contents)
return local_config
async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]:
# if provided in header, use custom guardrailing policy
if guardrails := extract_policy_from_headers(request):
return GuardrailRuleSet(
blocking_guardrails=[
Guardrail(
id="guardrail-from-header",
name="guardrails from request header",
content=guardrails,
action=GuardrailAction.BLOCK,
)
],
logging_guardrails=[],
)
+31
View File
@@ -0,0 +1,31 @@
"""Common guardrails data class."""
from enum import Enum
from typing import List
from dataclasses import dataclass
class GuardrailAction(str, Enum):
"""Enum representing the action to be taken for guardrail rules."""
BLOCK = "block"
LOG = "log"
@dataclass(frozen=True)
class Guardrail:
"""Represents a single guardrail rule."""
id: str
name: str
content: str
action: GuardrailAction
@dataclass(frozen=True)
class GuardrailRuleSet:
"""Grouped guardrail rules separated by their action."""
blocking_guardrails: List[Guardrail]
logging_guardrails: List[Guardrail]
+93
View File
@@ -0,0 +1,93 @@
"""Common Request context data class."""
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from common.config_manager import GatewayConfig
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
@dataclass(frozen=True)
class RequestContext:
"""Structured context for a request. Must be created via `RequestContext.create()`."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
# the set of guardrails to enforce for this request
guardrails: Optional[GuardrailRuleSet] = None
config: Dict[str, Any] = None
_created_via_factory: bool = field(
default=False, init=True, repr=False, compare=False
)
def __post_init__(self):
if not self._created_via_factory:
raise RuntimeError(
"RequestContext must be created using RequestContext.create()"
)
@classmethod
def create(
cls,
request_json: Dict[str, Any],
dataset_name: Optional[str] = None,
invariant_authorization: Optional[str] = None,
guardrails: Optional[GuardrailRuleSet] = None,
config: Optional[GatewayConfig] = None,
) -> "RequestContext":
"""Creates a new RequestContext instance, applying default guardrails if needed."""
# Convert GatewayConfig to a basic dict, excluding guardrails_from_file
context_config = {
key: value
for key, value in (config.__dict__.items() if config else {})
if key != "guardrails_from_file"
}
# If no guardrails are configured for the dataset on Explorer,
# and the config specifies guardrails_from_file, use that.
guardrails = guardrails
if (
(
not guardrails
or (
not guardrails.blocking_guardrails
and not guardrails.logging_guardrails
)
)
and config
and config.guardrails_from_file
):
# TODO: Support logging guardrails via file.
guardrails = GuardrailRuleSet(
blocking_guardrails=[
Guardrail(
id="default",
name="default",
content=config.guardrails_from_file,
action=GuardrailAction.BLOCK,
)
],
logging_guardrails=[],
)
return cls(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
guardrails=guardrails,
config=context_config,
_created_via_factory=True,
)
def __repr__(self) -> str:
return (
f"RequestContext("
f"request_json={self.request_json}, "
f"dataset_name={self.dataset_name}, "
f"invariant_authorization={self.invariant_authorization}, "
f"guardrails={self.guardrails}, "
f"config={self.config})"
)
-16
View File
@@ -1,16 +0,0 @@
"""Common Request context data class."""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from common.config_manager import GatewayConfig
@dataclass(frozen=True)
class RequestContextData:
"""Request context data class."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
config: Optional[GatewayConfig] = None
+84 -2
View File
@@ -3,15 +3,18 @@
import os
from typing import Any, Dict, List
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
from invariant_sdk.types.annotations import AnnotationCreate
import httpx
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def create_annotations_from_guardrails_errors(
guardrails_errors: List[dict],
guardrails_errors: List[dict], action: str = "block"
) -> List[AnnotationCreate]:
"""Create Explorer annotations from the guardrails errors."""
annotations = []
@@ -45,7 +48,10 @@ def create_annotations_from_guardrails_errors(
AnnotationCreate(
content=content,
address=r,
extra_metadata={"source": "guardrails-error"},
extra_metadata={
"source": "guardrails-error",
"guardrail-action": action,
},
)
)
return annotations
@@ -91,3 +97,79 @@ async def push_trace(
except Exception as e:
print(f"Failed to push trace: {e}")
return {"error": str(e)}
async def fetch_guardrails_from_explorer(
dataset_name: str, invariant_authorization: str
) -> GuardrailRuleSet:
"""Get the guardrails for the dataset.
Returns:
GuardrailRuleSet: The guardrails for the dataset grouped by their action.
"""
# TODO: Implement a single API in explorer backend which can return
# dataset details without requiring a username.
client = httpx.AsyncClient(
base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
headers={
"Authorization": invariant_authorization,
},
)
# Get the user details.
user_info_response = await client.get("/api/v1/user/info")
if user_info_response.status_code != 200:
raise ValueError(
f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
)
user_details = user_info_response.json()
username = user_details["username"]
# Get the dataset policies.
policies_response = await client.get(
f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy"
)
if policies_response.status_code != 200:
if policies_response.status_code == 404:
# If the dataset does not exist, return empty guardrails.
return GuardrailRuleSet(
blocking_guardrails=[],
logging_guardrails=[],
)
raise ValueError(
f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}"
)
policies_details = policies_response.json()
guardrails = policies_details.get("policies", [])
blocking_guardrails = []
logging_guardrails = []
for g in guardrails:
action = g["action"]
if not g["enabled"]:
# Skip guardrails that are not enabled.
continue
if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG):
print("[Warning] Skipping unknown guardrail action: ", action)
continue
guardrail = Guardrail(
id=g["id"],
name=g["name"],
content=g["content"],
action=GuardrailAction(action),
)
if action == GuardrailAction.BLOCK:
blocking_guardrails.append(guardrail)
else:
logging_guardrails.append(guardrail)
return GuardrailRuleSet(
blocking_guardrails=blocking_guardrails,
logging_guardrails=logging_guardrails,
)
+40 -14
View File
@@ -7,7 +7,8 @@ from typing import Any, Dict, List
from functools import wraps
import httpx
from common.request_context_data import RequestContextData
from common.guardrails import Guardrail
from common.request_context import RequestContext
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -81,21 +82,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None:
result.raise_for_status()
async def preload_guardrails(context: "RequestContextData") -> None:
async def preload_guardrails(context: "RequestContext") -> None:
"""
Preloads the guardrails for faster checking later.
Args:
context: RequestContextData object.
context: RequestContext object.
"""
if not context.config or not context.config.guardrails:
if not context.guardrails:
return
try:
task = asyncio.create_task(
_preload(context.config.guardrails, context.invariant_authorization)
)
asyncio.shield(task)
# Move these calls to a batch preload/validate API.
for blocking_guardrail in context.guardrails.blocking_guardrails:
task = asyncio.create_task(
_preload(blocking_guardrail.content, context.invariant_authorization)
)
asyncio.shield(task)
for logging_guadrail in context.guardrails.logging_guardrails:
task = asyncio.create_task(
_preload(logging_guadrail.content, context.invariant_authorization)
)
asyncio.shield(task)
except Exception as e:
print(f"Error scheduling preload_guardrails task: {e}")
@@ -322,14 +330,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
async def check_guardrails(
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
messages: List[Dict[str, Any]],
guardrails: List[Guardrail],
invariant_authorization: str,
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
This calls the batch check API of the Guardrails service.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (str): The guardrails to check against.
guardrails (List[Guardrail]): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
@@ -340,8 +351,11 @@ async def check_guardrails(
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
result = await client.post(
f"{url}/api/v1/policy/check",
json={"messages": messages, "policy": guardrails},
f"{url}/api/v1/policy/check/batch",
json={
"messages": messages,
"policies": [g.content for g in guardrails],
},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
@@ -351,8 +365,20 @@ async def check_guardrails(
raise Exception(
f"Guardrails check failed: {result.status_code} - {result.text}"
)
print(f"Guardrail check response: {result.json()}")
return result.json()
guardrails_result = result.json()
aggregated_errors = {"errors": []}
for res in guardrails_result.get("result", []):
aggregated_errors["errors"].extend(res.get("errors", []))
# check for any error_message
if error_message := res.get("error_message"):
return {
"errors": [
{"args": [error_message], "kwargs": {}, "ranges": []}
]
}
return aggregated_errors
except Exception as e:
print(f"Failed to verify guardrails: {e}")
# make sure runtime errors are also visible in e.g. Explorer
+92 -47
View File
@@ -5,20 +5,29 @@ import json
from typing import Any, Optional
import httpx
from regex import R
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers
from common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
)
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext
from converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -61,6 +70,7 @@ async def anthropic_v1_messages_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
):
"""Proxy calls to the Anthropic APIs"""
headers = {
@@ -83,21 +93,26 @@ async def anthropic_v1_messages_gateway(
data=request_body,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset from explorer.
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
if request_json.get("stream"):
return await handle_streaming_response(context, client, anthropic_request)
return await handle_non_streaming_response(context, client, anthropic_request)
def create_metadata(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {k: v for k, v in context.request_json.items() if k != "messages"}
@@ -108,7 +123,7 @@ def create_metadata(
def combine_request_and_response_messages(
context: RequestContextData, json_response: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
):
"""Combine the request and response messages"""
messages = []
@@ -117,42 +132,63 @@ def combine_request_and_response_messages(
{"role": "system", "content": context.request_json.get("system")}
)
messages.extend(context.request_json.get("messages", []))
if len(json_response) > 0:
messages.append(json_response)
if len(response_json) > 0:
messages.append(response_json)
return messages
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any]
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = combine_request_and_response_messages(context, json_response)
# Determine which guardrails to apply based on the action
guardrails = (
context.guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.guardrails.blocking_guardrails
)
if not guardrails:
return {}
messages = combine_request_and_response_messages(context, response_json)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=converted_messages,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
annotations = create_annotations_from_guardrails_errors(
guardrails_execution_result.get("errors", [])
guardrails_execution_result.get("errors", []), action="block"
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
# Combine the messages from the request body and Anthropic response
messages = combine_request_and_response_messages(context, merged_response)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
_ = await push_trace(
dataset_name=context.dataset_name,
messages=[converted_messages],
@@ -163,30 +199,32 @@ async def push_to_explorer(
class InstrumentedAnthropicResponse(InstrumentedResponse):
"""Instrumented response for Anthropic API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
# response data
self.response: Optional[httpx.Response] = None
self.response_string: Optional[str] = None
self.json_response: Optional[dict[str, Any]] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrailing response (if any)
self.guardrails_execution_result = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -220,10 +258,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
)
async def request(self):
"""Make the request to the Anthropic API."""
self.response = await self.client.send(self.anthropic_request)
try:
json_response = self.response.json()
response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
@@ -232,11 +271,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
detail=response_json.get("error", "Unknown error from Anthropic"),
)
self.json_response = json_response
self.response_string = json.dumps(json_response)
self.response_json = response_json
self.response_string = json.dumps(response_json)
return self._make_response(
content=self.response_string,
@@ -261,13 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
# ensure the response data is available
assert self.response is not None, "response is None"
assert self.json_response is not None, "json_response is None"
assert self.response_json is not None, "response_json is None"
assert self.response_string is not None, "response_string is None"
if self.context.config and self.context.config.guardrails:
if self.context.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.response_json,
)
if guardrails_execution_result.get("errors", []):
guardrail_response_string = json.dumps(
@@ -283,7 +324,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
guardrails_execution_result,
)
)
@@ -300,13 +341,13 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context, self.json_response, guardrails_execution_result
self.context, self.response_json, guardrails_execution_result
)
)
async def handle_non_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> Response:
@@ -320,17 +361,19 @@ async def handle_non_streaming_response(
return await response.instrumented_request()
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
"""Instrumented streaming response for Anthropic API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
@@ -342,9 +385,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -392,6 +437,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
yield chunk
async def on_chunk(self, chunk):
"""Process the chunk and update the merged_response"""
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
return
@@ -400,14 +446,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
process_chunk(decoded_chunk, self.merged_response)
# on last stream chunk, run output guardrails
if (
"event: message_stop" in decoded_chunk
and self.context.config
and self.context.config.guardrails
):
if "event: message_stop" in decoded_chunk and self.context.guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -420,7 +464,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
}
)
# yield an extra error chunk (without preventing the original chunk to go through after,
# yield an extra error chunk (without preventing the original chunk
# to go through after,
# so client gets the proper message_stop event still)
return ExtraItem(
value=f"event: error\ndata: {error_chunk}\n\n".encode()
@@ -440,12 +485,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
async def handle_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
response = InstrumentedAnthropicStreamingResposne(
response = InstrumentedAnthropicStreamingResponse(
context=context,
client=client,
anthropic_request=anthropic_request,
+84 -31
View File
@@ -5,16 +5,27 @@ import json
from typing import Any, Literal, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
from fastapi.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers
from common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
)
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -23,8 +34,6 @@ from integrations.guardrails import (
preload_guardrails,
check_guardrails,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardrails import check_guardrails, preload_guardrails
gateway = APIRouter()
@@ -43,6 +52,7 @@ async def gemini_generate_content_gateway(
None, title="Response Format", description="Set to 'sse' for streaming"
),
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
) -> Response:
"""Proxy calls to the Gemini GenerateContent API"""
if endpoint not in ["generateContent", "streamGenerateContent"]:
@@ -76,14 +86,19 @@ async def gemini_generate_content_gateway(
headers=headers,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
if alt == "sse" or endpoint == "streamGenerateContent":
return await stream_response(
context,
@@ -98,16 +113,18 @@ async def gemini_generate_content_gateway(
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
"""Instrumented streaming response for Gemini API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
@@ -124,6 +141,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
location: Literal["request", "response"],
guardrails_execution_result: dict[str, Any],
) -> dict:
"""Create a refusal response for the given request or response"""
return {
"candidates": [
{
@@ -157,10 +175,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -184,6 +205,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
)
async def event_generator(self):
"""Event generator for streaming responses"""
response = await self.client.send(self.gemini_request, stream=True)
if response.status_code != 200:
@@ -199,6 +221,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
yield chunk
async def on_chunk(self, chunk):
"""Processes each chunk of the streaming response"""
chunk_text = chunk.decode().strip()
if not chunk_text:
return
@@ -210,12 +233,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
if (
self.merged_response.get("candidates", [])
and self.merged_response.get("candidates")[0].get("finishReason", "")
and self.context.config
and self.context.config.guardrails
and self.context.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -254,7 +278,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
async def stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
@@ -269,7 +293,6 @@ async def stream_response(
async def event_generator():
async for chunk in response.instrumented_event_generator():
yield chunk
print("chunk", chunk)
return StreamingResponse(
event_generator(),
@@ -332,7 +355,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) ->
def create_metadata(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {
@@ -352,32 +375,53 @@ def create_metadata(
async def get_guardrails_check_result(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
# Determine which guardrails to apply based on the action
guardrails = (
context.guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.guardrails.blocking_guardrails
)
if not guardrails:
return {}
converted_requests = convert_request(context.request_json)
converted_responses = convert_response(response_json)
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=converted_requests + converted_responses,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
response_json: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
annotations = create_annotations_from_guardrails_errors(
guardrails_execution_result.get("errors", [])
guardrails_execution_result.get("errors", []), action="block"
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=response_json,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
converted_requests = convert_request(context.request_json)
converted_responses = convert_response(response_json)
@@ -391,16 +435,18 @@ async def push_to_explorer(
class InstrumentedGeminiResponse(InstrumentedResponse):
"""Instrumented response for Gemini API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
@@ -412,10 +458,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
self.guardrails_execution_result: Optional[dict[str, Any]] = None
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -463,6 +512,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
)
async def request(self):
"""Makes the request to the Gemini API and return the response"""
self.response = await self.client.send(self.gemini_request)
response_string = self.response.text
@@ -489,13 +539,16 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
)
async def on_end(self):
"""Runs when the request ends."""
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
if self.context.config and self.context.config.guardrails:
if self.context.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.response_json
self.context,
action=GuardrailAction.BLOCK,
response_json=self.response_json,
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
@@ -539,7 +592,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
async def handle_non_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
+102 -52
View File
@@ -5,14 +5,26 @@ import json
from typing import Any, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers
from common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
)
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -20,8 +32,6 @@ from integrations.guardrails import (
check_guardrails,
preload_guardrails,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
gateway = APIRouter()
@@ -48,6 +58,7 @@ async def openai_chat_completions_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
) -> Response:
"""Proxy calls to the OpenAI APIs"""
headers = {
@@ -71,14 +82,19 @@ async def openai_chat_completions_gateway(
headers=headers,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
guardrails=header_guardrails or dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
if request_json.get("stream", False):
return await handle_stream_response(
context,
@@ -91,19 +107,20 @@ async def openai_chat_completions_gateway(
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
"""
Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
Does a streaming OpenAI completion request at the core, but also checks guardrails
before (concurrent) and after the request.
"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
@@ -130,10 +147,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
self.tool_call_mapping_by_index = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -163,6 +185,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
async def on_chunk(self, chunk):
"""Processes each chunk of the stream and checks guardrails at the end of the stream"""
# process and check each chunk
chunk_text = chunk.decode().strip()
if not chunk_text:
@@ -178,14 +201,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
if (
"data: [DONE]" in chunk_text
and self.context.config
and self.context.config.guardrails
):
if "data: [DONE]" in chunk_text and self.context.guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -203,7 +224,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
# push will happen in on_end
async def on_end(self):
"""Sends full merged response to the exploree."""
"""Sends full merged response to the explorer."""
# don't block on the response from explorer (.create_task)
if self.context.dataset_name:
asyncio.create_task(
@@ -213,10 +234,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
async def event_generator(self):
"""
Actual OpenAI stream response.
"""
"""Actual OpenAI stream response."""
response = await self.client.send(self.open_ai_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
@@ -233,7 +251,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
async def handle_stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
@@ -388,7 +406,7 @@ def update_existing_choice_with_delta(
def create_metadata(
context: RequestContextData, merged_response: dict[str, Any]
context: RequestContext, merged_response: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {
@@ -408,7 +426,7 @@ def create_metadata(
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
@@ -417,12 +435,26 @@ async def push_to_explorer(
# or if the guardrails check returned errors.
guardrails_execution_result = guardrails_execution_result or {}
guardrails_errors = guardrails_execution_result.get("errors", [])
if guardrails_errors or not (
annotations = create_annotations_from_guardrails_errors(
guardrails_errors, action="block"
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
if annotations or not (
merged_response.get("choices")
and merged_response["choices"][0].get("finish_reason")
not in FINISH_REASON_TO_PUSH_TRACE
):
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
# Combine the messages from the request body and the choices from the OpenAI response
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in merged_response.get("choices", [])]
@@ -436,18 +468,29 @@ async def push_to_explorer(
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any] | None = None
context: RequestContext,
action: GuardrailAction,
response_json: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = list(context.request_json.get("messages", []))
# Determine which guardrails to apply based on the action
guardrails = (
context.guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.guardrails.blocking_guardrails
)
if json_response is not None:
messages += [choice["message"] for choice in json_response.get("choices", [])]
if not guardrails:
return {}
messages = list(context.request_json.get("messages", []))
if response_json is not None:
messages += [choice["message"] for choice in response_json.get("choices", [])]
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
@@ -455,35 +498,39 @@ async def get_guardrails_check_result(
class InstrumentedOpenAIResponse(InstrumentedResponse):
"""
Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
Does an OpenAI completion request at the core, but also checks guardrails
before (concurrent) and after the request.
"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
# request outputs
self.response: Optional[httpx.Response] = None
self.json_response: Optional[dict[str, Any]] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
async def on_start(self):
"""Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
if self.context.config and self.context.config.guardrails:
"""
Checks guardrails in a pipelined fashion, before processing
the first chunk (for input guardrailing)
"""
if self.context.guardrails:
# block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context
self.context, action=GuardrailAction.BLOCK
)
if self.guardrails_execution_result.get("errors", []):
# Push annotated trace to the explorer - don't block on its response
@@ -516,7 +563,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.response = await self.client.send(self.open_ai_request)
try:
self.json_response = self.response.json()
self.response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
@@ -525,10 +572,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
detail=self.response_json.get("error", "Unknown error from OpenAI API"),
)
response_string = json.dumps(self.json_response)
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
return Response(
@@ -541,23 +588,26 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
async def on_end(self):
"""Postprocesses the OpenAI response and potentially replace it with a guardrails error."""
# these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
# these two request outputs are guaranteed to be available by the time we reach
# this point (after self.request() was executed)
# nevertheless, we check for them to avoid any potential issues
assert (
self.response is not None
), "on_end called before 'self.response' was available"
assert (
self.json_response is not None
), "on_end called before 'self.json_response' was available"
self.response_json is not None
), "on_end called before 'self.response_json' was available"
# extract original response status code
response_code = self.response.status_code
# if we have guardrails, check the response
if self.context.config and self.context.config.guardrails:
if self.context.guardrails:
# run guardrails again, this time on request + response
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.response_json,
)
if self.guardrails_execution_result.get("errors", []):
response_string = json.dumps(
@@ -573,7 +623,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
self.guardrails_execution_result,
)
)
@@ -592,7 +642,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
# include any guardrailing errors if available
self.guardrails_execution_result,
)
@@ -600,7 +650,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
async def handle_non_stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
+6 -1
View File
@@ -93,7 +93,12 @@ integration_tests() {
fi
echo "File successfully downloaded: $FILE"
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py"
if [[ -z "$INVARIANT_API_KEY" ]]; then
echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests."
exit 1
fi
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/integration_test_guardrails_via_file.py"
if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then
if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then
TEST_GUARDRAILS_FILE_PATH=$(realpath "$TEST_GUARDRAILS_FILE_PATH")
@@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{
"ANTHROPIC_API_KEY": anthropic_api_key
+ ";invariant-auth=<not needed for test>"
},
{"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix},
):
client = anthropic.Anthropic(
http_client=Client(),
@@ -12,10 +12,11 @@ from typing import Dict, List
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_anthropic_client
import anthropic
import pytest
import requests
from httpx import Client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -26,14 +27,8 @@ class WeatherAgent:
def __init__(self, gateway_url, push_to_explorer):
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
self.client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{gateway_url}/api/v1/gateway/{self.dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
self.client = get_anthropic_client(
gateway_url, push_to_explorer, self.dataset_name
)
self.get_weather_function = {
"name": "get_weather",
@@ -8,10 +8,10 @@ import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
from utils import get_anthropic_client
import pytest
import requests
from httpx import Client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -26,15 +26,10 @@ async def test_response_without_tool_call(
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
client = get_anthropic_client(
gateway_url,
push_to_explorer,
dataset_name,
)
cities = ["zurich", "new york", "london"]
@@ -91,16 +86,7 @@ async def test_streaming_response_without_tool_call(
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
)
client = get_anthropic_client(gateway_url, push_to_explorer, dataset_name)
cities = ["zurich", "new york", "london"]
queries = [
@@ -60,6 +60,7 @@ services:
app-api:
container_name: invariant-gateway-test-explorer-app-api
image: ghcr.io/invariantlabs-ai/explorer/app-api:latest
pull_policy: always
platform: linux/amd64
depends_on:
database:
@@ -8,9 +8,10 @@ import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_gemini_client
import pytest
import requests
from google import genai
from google.genai import types
# Pytest plugins
@@ -143,18 +144,7 @@ async def test_generate_content_with_tool_call(
without streaming.
"""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
}, # This key is not used for local tests
},
)
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
request = {
"model": "gemini-2.0-flash",
@@ -10,6 +10,8 @@ from unittest.mock import patch
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_gemini_client
import pytest
import PIL.Image
import requests
@@ -29,17 +31,8 @@ async def test_generate_content(
):
"""Test the generate content gateway calls without tool calling."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
}, # This key is not used for local tests
},
)
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
request = {
"model": "gemini-2.0-flash",
"contents": "What is the capital of France?",
@@ -115,18 +108,8 @@ async def test_generate_content_with_image(
):
"""Test that generate content gateway calls work with image."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
}, # This key is not used for local tests
},
)
image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png"
image = PIL.Image.open(image_path)
@@ -181,9 +164,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
gemini_api_key = os.getenv("GEMINI_API_KEY")
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{"GEMINI_API_KEY": gemini_api_key + ";invariant-auth=<not needed for test>"},
{"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix},
):
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
@@ -194,14 +178,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
chat_response = client.models.generate_content(
model="gemini-2.0-flash",
contents="What is the capital of Spain?",
contents="What is the capital of Denmark?",
config={
"maxOutputTokens": 100,
},
)
# Verify the chat response
assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper()
assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper()
expected_assistant_message = chat_response.candidates[0].content.parts[0].text
# Wait for the trace to be saved
@@ -228,7 +212,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
assert trace["messages"] == [
{
"role": "user",
"content": [{"text": "What is the capital of Spain?", "type": "text"}],
"content": [{"text": "What is the capital of Denmark?", "type": "text"}],
},
{
"role": "assistant",
@@ -8,10 +8,11 @@ import time
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
from httpx import Client
from anthropic import Anthropic, APIStatusError, BadRequestError
from anthropic import APIStatusError, BadRequestError
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -32,16 +33,10 @@ async def test_message_content_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
client = get_anthropic_client(
gateway_url,
push_to_explorer,
dataset_name,
)
request = {
@@ -161,16 +156,10 @@ async def test_tool_call_guardrail_from_file(
}
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
client = get_anthropic_client(
gateway_url,
push_to_explorer,
dataset_name,
)
if not do_stream:
@@ -255,16 +244,10 @@ async def test_input_from_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
client = get_anthropic_client(
gateway_url,
push_to_explorer,
dataset_name,
)
request = {
@@ -332,3 +315,276 @@ async def test_input_from_guardrail_from_file(
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("do_stream", [True, False])
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = get_anthropic_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
action="block",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
action="log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
# Ask about the capital of Spain
# This should not be blocked by the guardrails from the explorer when we push to explorer
# because the file based guardrails are overridden by the explorer guardrails
spain_request = {
"model": "claude-3-5-sonnet-20241022",
"messages": [{"role": "user", "content": "What is the capital of Spain?"}],
"max_tokens": 100,
}
if not do_stream:
chat_response = client.messages.create(
**spain_request,
stream=False,
)
assert "Madrid" in chat_response.content[0].text
else:
chat_response = client.messages.create(
**spain_request,
stream=True,
)
merged_content = ""
for chunk in chat_response:
if chunk.type == "content_block_delta":
merged_content += chunk.delta.text
assert "Madrid" in merged_content
# Ask about Shrek
# This should be blocked by the guardrails from the explorer
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
shrek_request = {
"model": "claude-3-5-sonnet-20241022",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.messages.create(
**shrek_request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
else:
with pytest.raises(APIStatusError) as exc_info:
chat_response = client.messages.create(
**shrek_request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 2
trace_id = traces[1]["id"]
# Fetch the second trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 2
assert (
annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
)
assert (
annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = get_anthropic_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "claude-3-5-sonnet-20241022",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if is_block_action:
if do_stream:
with pytest.raises(APIStatusError) as exc_info:
chat_response = client.messages.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
else:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.messages.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
_ = client.messages.create(
**request,
stream=True,
)
else:
_ = client.messages.create(
**request,
stream=False,
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2 if not is_block_action else 1
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)
@@ -8,9 +8,10 @@ import time
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
from httpx import Client
from google import genai
# Pytest plugins
@@ -30,17 +31,10 @@ async def test_message_content_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
},
client = get_gemini_client(
gateway_url,
push_to_explorer,
dataset_name,
)
request = {
@@ -141,17 +135,10 @@ async def test_tool_call_guardrail_from_file(
)
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
},
client = get_gemini_client(
gateway_url,
push_to_explorer,
dataset_name,
)
request = {
@@ -244,17 +231,10 @@ async def test_input_from_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
},
client = get_gemini_client(
gateway_url,
push_to_explorer,
dataset_name,
)
request = {
@@ -323,6 +303,259 @@ async def test_input_from_guardrail_from_file(
)
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize("do_stream", [True, False])
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = get_gemini_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
action="block",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
action="log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
# Ask about the capital of Spain
# This should not be blocked by the guardrails from the explorer when we push to explorer
# because the file based guardrails are overridden by the explorer guardrails
spain_request = {
"model": "gemini-2.0-flash",
"contents": "What is the capital of Spain?",
"config": {
"maxOutputTokens": 100,
},
}
if not do_stream:
chat_response = client.models.generate_content(**spain_request)
assert "Madrid" in chat_response.candidates[0].content.parts[0].text
else:
chat_response = client.models.generate_content_stream(**spain_request)
merged_content = ""
for chunk in chat_response:
if (
chunk.candidates
and chunk.candidates[0].content
and chunk.candidates[0].content.parts
):
for text_part in chunk.candidates[0].content.parts:
merged_content += text_part.text
assert "Madrid" in merged_content
# Ask about Shrek
# This should be blocked by the guardrails from the explorer
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
shrek_request = {
"model": "gemini-2.0-flash",
"contents": user_prompt,
"config": {
"maxOutputTokens": 100,
},
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
client.models.generate_content(**shrek_request)
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
else:
response = client.models.generate_content_stream(**shrek_request)
assert_is_streamed_refusal(
response,
[
"[Invariant] The response did not pass the guardrails",
"ogre detected in response",
],
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 2
trace_id = traces[1]["id"]
# Fetch the second trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2
assert trace["messages"][0] == {
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt,
}
],
}
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 2
assert (
annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
)
assert (
annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
)
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = get_gemini_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "gemini-2.0-flash",
"contents": user_prompt,
"config": {
"maxOutputTokens": 100,
},
}
if is_block_action:
if do_stream:
chat_response = client.models.generate_content_stream(**request)
assert_is_streamed_refusal(
chat_response,
[
"[Invariant] The request did not pass the guardrails",
"pun detected in user message",
],
)
else:
with pytest.raises(genai.errors.ClientError) as exc_info:
chat_response = client.models.generate_content(**request)
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
response = client.models.generate_content_stream(**request)
for _ in response:
pass
else:
_ = client.models.generate_content(**request)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2 if not is_block_action else 1
assert trace["messages"][0] == {
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt,
}
],
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)
def is_refusal(chunk):
return (
len(chunk.candidates) == 1
@@ -8,10 +8,11 @@ import time
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
from httpx import Client
from openai import OpenAI, BadRequestError, APIError
from openai import BadRequestError, APIError
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -30,17 +31,7 @@ async def test_message_content_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
request = {
"model": "gpt-4o",
@@ -161,17 +152,7 @@ async def test_tool_call_guardrail_from_file(
}
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
@@ -259,17 +240,7 @@ async def test_input_from_guardrail_from_file(
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
request = {
"model": "gpt-4o",
@@ -349,3 +320,268 @@ async def test_input_from_guardrail_from_file(
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("do_stream", [True, False])
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = get_open_ai_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
action="block",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
action="log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
# Ask about the capital of Spain
# This should not be blocked by the guardrails from the explorer when we push to explorer
# because the file based guardrails are overridden by the explorer guardrails
spain_request = {
"model": "gpt-4o",
"messages": [{"role": "user", "content": "What is the capital of Spain?"}],
"max_tokens": 100,
}
if not do_stream:
chat_response = client.chat.completions.create(
**spain_request,
stream=False,
)
assert "Madrid" in chat_response.choices[0].message.content
else:
chat_response = client.chat.completions.create(
**spain_request,
stream=True,
)
merged_content = ""
for chunk in chat_response:
if chunk.choices[0].delta.content:
merged_content += chunk.choices[0].delta.content
assert "Madrid" in merged_content
# Ask about Shrek
# This should be blocked by the guardrails from the explorer
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
shrek_request = {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.chat.completions.create(
**shrek_request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
else:
with pytest.raises(APIError) as exc_info:
chat_response = client.chat.completions.create(
**shrek_request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 2
trace_id = traces[1]["id"]
# Fetch the second trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 2
assert (
annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
)
assert (
annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = get_open_ai_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if is_block_action:
if do_stream:
with pytest.raises(APIError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
else:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
_ = client.chat.completions.create(
**request,
stream=True,
)
else:
_ = client.chat.completions.create(
**request,
stream=False,
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 1 if is_block_action else 2
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)
@@ -1,4 +1,4 @@
"""Test the guardrails from file with the OpenAI route."""
"""Test the guardrails from header with the OpenAI route."""
import os
import sys
@@ -136,9 +136,7 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if:
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_invalid_guardrail_in_header(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
async def test_invalid_guardrail_in_header(gateway_url, do_stream, push_to_explorer):
"""Test the message content guardrail."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
@@ -178,7 +176,8 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if:
stream=False,
)
assert "Gateway: Guardrails check failed" in str(
print(exc_info.value.message, flush=True)
assert "Failed to create policy from policy source." in str(
exc_info.value
), "guardrails check fails because of an invalid guardrailing rule"
assert "illegal statement" in str(
@@ -9,10 +9,10 @@ import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_open_ai_client
import pytest
import requests
from httpx import Client
from openai import OpenAI
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -28,17 +28,7 @@ async def test_chat_completion_with_tool_call_without_streaming(
without streaming.
"""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
chat_response = client.chat.completions.create(
model="gpt-4o",
@@ -146,17 +136,7 @@ async def test_chat_completion_with_tool_call_with_streaming(
while streaming.
"""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
chat_response = client.chat.completions.create(
model="gpt-4o",
@@ -11,6 +11,8 @@ from unittest.mock import patch
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_open_ai_client
import pytest
import requests
from httpx import Client
@@ -30,17 +32,7 @@ async def test_chat_completion(
):
"""Test the chat completions gateway calls without tool calling."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
chat_response = client.chat.completions.create(
model="gpt-4o",
@@ -103,17 +95,8 @@ async def test_chat_completion_with_image(
):
"""Test the chat completions gateway works with image."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png"
with image_path.open("rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
@@ -189,9 +172,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
openai_api_key = os.getenv("OPENAI_API_KEY")
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{"OPENAI_API_KEY": openai_api_key + ";invariant-auth=<not needed for test>"},
{"OPENAI_API_KEY": openai_api_key + invariant_key_suffix},
):
client = OpenAI(
http_client=Client(),
@@ -252,7 +236,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
@@ -1,15 +1,15 @@
# For Anthropic and OpenAI tests:
# For Anthropic and OpenAI tests
raise "Madrid detected in the response" if:
(msg: Message)
"Madrid" in msg.content
# For Gemini tests:
# For Gemini tests
raise "Dublin detected in the response" if:
(msg: Message)
(chunk: str) in text(msg.content)
"Dublin" in chunk
# For all provider tests:
# For all provider tests
raise "get_capital is called with Germany as argument" if:
(call: ToolCall)
call is tool:get_capital
+105
View File
@@ -0,0 +1,105 @@
"""Common utilities for integration tests."""
import os
import uuid
from typing import Any, Dict, Literal, Optional
from httpx import AsyncClient, Client
from openai import OpenAI
from google import genai
from anthropic import Anthropic
def get_open_ai_client(
gateway_url: str, push_to_explorer: bool, dataset_name: str
) -> OpenAI:
"""Create an OpenAI client for integration tests."""
return OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
def get_anthropic_client(
gateway_url: str, push_to_explorer: bool, dataset_name: str
) -> Anthropic:
"""Create an Anthropic client for integration tests."""
return Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
)
def get_gemini_client(
gateway_url: str, push_to_explorer: bool, dataset_name: str
) -> genai.Client:
"""Create a Gemini client for integration tests."""
return genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
},
)
async def create_dataset(
explorer_api_url: str,
invariant_authorization: str,
dataset_name: Optional[str] = None,
) -> Dict[str, Any]:
"""Create a dataset in the Explorer API."""
client = Client(base_url=explorer_api_url)
response = client.post(
"/api/v1/dataset/create",
json={"name": dataset_name if dataset_name else f"test-dataset-{uuid.uuid4()}"},
headers={"Authorization": invariant_authorization},
timeout=5,
)
if response.status_code != 200:
raise ValueError(
f"Failed to create dataset: {response.status_code}, {response.text}"
)
return response.json()
async def add_guardrail_to_dataset(
explorer_api_url: str,
dataset_id: str,
policy: str,
action: Literal["block", "log"],
invariant_authorization: str,
) -> Dict[str, Any]:
"""Add a guardrail to a dataset."""
client = Client(base_url=explorer_api_url)
response = client.post(
f"/api/v1/dataset/{dataset_id}/policy",
json={
"action": action,
"policy": policy,
"name": f"test-guardrail-{uuid.uuid4()}",
},
headers={"Authorization": invariant_authorization},
timeout=5,
)
if response.status_code != 200:
raise ValueError(
f"Failed to add guardrail: {response.status_code}, {response.text}"
)
return response.json()