Fetch guardrails from explorer. These have higher precedence than than the guardrails from file.

This commit is contained in:
Hemang
2025-04-01 14:16:05 +02:00
committed by Hemang Sarkar
parent f45a973f51
commit 050ec1ba58
17 changed files with 477 additions and 136 deletions
+2 -2
View File
@@ -11,7 +11,7 @@ class GatewayConfig:
"""Common configurations for the Gateway Server."""
def __init__(self):
self.guardrails = self._load_guardrails_from_file()
self.guardrails_from_file = self._load_guardrails_from_file()
def _load_guardrails_from_file(self) -> str:
"""
@@ -48,7 +48,7 @@ class GatewayConfig:
raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e
def __repr__(self) -> str:
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})"
class GatewayConfigManager:
+31
View File
@@ -0,0 +1,31 @@
"""Common guardrails data class."""
from enum import Enum
from typing import List
from dataclasses import dataclass
class GuardrailAction(str, Enum):
"""Enum representing the action to be taken for guardrail rules."""
BLOCK = "block"
LOG = "log"
@dataclass(frozen=True)
class Guardrail:
"""Represents a single guardrail rule."""
id: str
name: str
content: str
action: GuardrailAction
@dataclass(frozen=True)
class DatasetGuardrails:
"""Grouped guardrail rules separated by their action."""
blocking_guardrails: List[Guardrail]
logging_guardrails: List[Guardrail]
+92
View File
@@ -0,0 +1,92 @@
"""Common Request context data class."""
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from common.config_manager import GatewayConfig
from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction
@dataclass(frozen=True)
class RequestContext:
"""Structured context for a request. Must be created via `RequestContext.create()`."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
dataset_guardrails: Optional[DatasetGuardrails] = None
config: Dict[str, Any] = None
_created_via_factory: bool = field(
default=False, init=True, repr=False, compare=False
)
def __post_init__(self):
if not self._created_via_factory:
raise RuntimeError(
"RequestContext must be created using RequestContext.create()"
)
@classmethod
def create(
cls,
request_json: Dict[str, Any],
dataset_name: Optional[str] = None,
invariant_authorization: Optional[str] = None,
dataset_guardrails: Optional[DatasetGuardrails] = None,
config: Optional[GatewayConfig] = None,
) -> "RequestContext":
"""Creates a new RequestContext instance, applying default guardrails if needed."""
# Convert GatewayConfig to a basic dict, excluding guardrails_from_file
context_config = {
key: value
for key, value in (config.__dict__.items() if config else {})
if key != "guardrails_from_file"
}
# If no guardrails are configured for the dataset on Explorer,
# and the config specifies guardrails_from_file, use that.
guardrails = dataset_guardrails
if (
(
not dataset_guardrails
or (
not dataset_guardrails.blocking_guardrails
and not dataset_guardrails.logging_guardrails
)
)
and config
and config.guardrails_from_file
):
# TODO: Support logging guardrails via file.
guardrails = DatasetGuardrails(
blocking_guardrails=[
Guardrail(
id="default",
name="default",
content=config.guardrails_from_file,
action=GuardrailAction.BLOCK,
)
],
logging_guardrails=[],
)
return cls(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
dataset_guardrails=guardrails,
config=context_config,
_created_via_factory=True,
)
def __repr__(self) -> str:
return (
f"RequestContext("
f"request_json={self.request_json}, "
f"dataset_name={self.dataset_name}, "
f"invariant_authorization={self.invariant_authorization}, "
f"dataset_guardrails={self.dataset_guardrails}, "
f"config={self.config})"
)
-16
View File
@@ -1,16 +0,0 @@
"""Common Request context data class."""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from common.config_manager import GatewayConfig
@dataclass(frozen=True)
class RequestContextData:
"""Request context data class."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
config: Optional[GatewayConfig] = None
+79
View File
@@ -3,10 +3,13 @@
import os
from typing import Any, Dict, List
from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
from invariant_sdk.types.annotations import AnnotationCreate
import httpx
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -91,3 +94,79 @@ async def push_trace(
except Exception as e:
print(f"Failed to push trace: {e}")
return {"error": str(e)}
async def fetch_guardrails_from_explorer(
dataset_name: str, invariant_authorization: str
) -> DatasetGuardrails:
"""Get the guardrails for the dataset.
Returns:
DatasetGuardrails: The guardrails for the dataset grouped by their action.
"""
# TODO: Implement a single API in explorer backend which can return
# dataset details without requiring a username.
client = httpx.AsyncClient(
base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
headers={
"Invariant-Authorization": invariant_authorization,
},
)
# Get the user details.
user_info_response = await client.get("/api/v1/user/info")
if user_info_response.status_code != 200:
raise ValueError(
f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
)
user_details = user_info_response.json()
username = user_details["username"]
# Get the dataset policies.
policies_response = await client.get(
f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy"
)
if policies_response.status_code != 200:
if policies_response.status_code == 404:
# If the dataset does not exist, return empty guardrails.
return DatasetGuardrails(
blocking_guardrails=[],
logging_guardrails=[],
)
raise ValueError(
f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}"
)
policies_details = policies_response.json()
guardrails = policies_details.get("policies", [])
blocking_guardrails = []
logging_guardrails = []
for g in guardrails:
action = g["action"]
if not g["enabled"]:
# Skip guardrails that are not enabled.
continue
if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG):
print("[Warning] Skipping unknown guardrail action: ", action)
continue
guardrail = Guardrail(
id=g["id"],
name=g["name"],
content=g["content"],
action=GuardrailAction(action),
)
if action == GuardrailAction.BLOCK:
blocking_guardrails.append(guardrail)
else:
logging_guardrails.append(guardrail)
return DatasetGuardrails(
blocking_guardrails=blocking_guardrails,
logging_guardrails=logging_guardrails,
)
+55 -13
View File
@@ -1,13 +1,15 @@
"""Utility functions for Guardrails execution."""
import asyncio
import json
import os
import time
from typing import Any, Dict, List
from functools import wraps
import httpx
from common.request_context_data import RequestContextData
from common.guardrails import Guardrail
from common.request_context import RequestContext
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -81,21 +83,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None:
result.raise_for_status()
async def preload_guardrails(context: "RequestContextData") -> None:
async def preload_guardrails(context: "RequestContext") -> None:
"""
Preloads the guardrails for faster checking later.
Args:
context: RequestContextData object.
context: RequestContext object.
"""
if not context.config or not context.config.guardrails:
if not context.dataset_guardrails:
return
try:
task = asyncio.create_task(
_preload(context.config.guardrails, context.invariant_authorization)
)
asyncio.shield(task)
# Move these calls to a batch preload/validate API.
for blocking_guardrail in context.dataset_guardrails.blocking_guardrails:
task = asyncio.create_task(
_preload(blocking_guardrail.content, context.invariant_authorization)
)
asyncio.shield(task)
for logging_guadrail in context.dataset_guardrails.logging_guardrails:
task = asyncio.create_task(
_preload(logging_guadrail.content, context.invariant_authorization)
)
asyncio.shield(task)
except Exception as e:
print(f"Error scheduling preload_guardrails task: {e}")
@@ -322,14 +331,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
async def check_guardrails(
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
messages: List[Dict[str, Any]],
guardrails: List[Guardrail],
invariant_authorization: str,
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
This calls the batch check API of the Guardrails service.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (str): The guardrails to check against.
guardrails (List[Guardrail]): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
@@ -339,9 +351,34 @@ async def check_guardrails(
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
print(
"Hello there this is the request to guardrails: ",
json.dumps(
{
"messages": messages,
"policies": [g.content for g in guardrails],
},
indent=2,
),
flush=True,
)
print(
"Hello there this is the request to guardrails: ",
json.dumps(
{
"Authorization": invariant_authorization,
"Accept": "application/json",
},
indent=2,
),
flush=True,
)
result = await client.post(
f"{url}/api/v1/policy/check",
json={"messages": messages, "policy": guardrails},
f"{url}/api/v1/policy/check/batch",
json={
"messages": messages,
"policies": [g.content for g in guardrails],
},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
@@ -352,7 +389,12 @@ async def check_guardrails(
f"Guardrails check failed: {result.status_code} - {result.text}"
)
print(f"Guardrail check response: {result.json()}")
return result.json()
guardrails_result = result.json()
aggregated_errors = {"errors": []}
for res in guardrails_result.get("result", []):
aggregated_errors["errors"].extend(res.get("errors", []))
return aggregated_errors
except Exception as e:
print(f"Failed to verify guardrails: {e}")
return {"error": str(e)}
+68 -29
View File
@@ -14,11 +14,16 @@ from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from common.request_context_data import RequestContextData
from common.guardrails import GuardrailAction
from common.request_context import RequestContext
from converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -83,10 +88,17 @@ async def anthropic_v1_messages_gateway(
data=request_body,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset from explorer.
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
dataset_guardrails=dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
@@ -97,7 +109,7 @@ async def anthropic_v1_messages_gateway(
def create_metadata(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {k: v for k, v in context.request_json.items() if k != "messages"}
@@ -108,7 +120,7 @@ def create_metadata(
def combine_request_and_response_messages(
context: RequestContextData, json_response: dict[str, Any]
context: RequestContext, json_response: dict[str, Any]
):
"""Combine the request and response messages"""
messages = []
@@ -123,23 +135,32 @@ def combine_request_and_response_messages(
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any]
context: RequestContext, action: GuardrailAction, json_response: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
# Determine which guardrails to apply based on the action
guardrails = (
context.dataset_guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.dataset_guardrails.blocking_guardrails
)
if not guardrails:
return {}
messages = combine_request_and_response_messages(context, json_response)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=converted_messages,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
@@ -163,14 +184,16 @@ async def push_to_explorer(
class InstrumentedAnthropicResponse(InstrumentedResponse):
"""Instrumented response for Anthropic API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
@@ -184,9 +207,9 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, json_response={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -264,10 +287,17 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
assert self.json_response is not None, "json_response is None"
assert self.response_string is not None, "response_string is None"
if self.context.config and self.context.config.guardrails:
if self.context.dataset_guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.json_response,
)
print(
"Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ",
guardrails_execution_result,
flush=True,
)
if guardrails_execution_result.get("errors", []):
guardrail_response_string = json.dumps(
@@ -306,7 +336,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
async def handle_non_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> Response:
@@ -320,17 +350,19 @@ async def handle_non_streaming_response(
return await response.instrumented_request()
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
"""Instrumented streaming response for Anthropic API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
@@ -342,9 +374,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -392,6 +426,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
yield chunk
async def on_chunk(self, chunk):
"""Process the chunk and update the merged_response"""
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
return
@@ -400,14 +435,17 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
process_chunk(decoded_chunk, self.merged_response)
# on last stream chunk, run output guardrails
if (
"event: message_stop" in decoded_chunk
and self.context.config
and self.context.config.guardrails
):
if "event: message_stop" in decoded_chunk and self.context.dataset_guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
)
print(
"Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ",
self.guardrails_execution_result,
flush=True,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -420,7 +458,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
}
)
# yield an extra error chunk (without preventing the original chunk to go through after,
# yield an extra error chunk (without preventing the original chunk
# to go through after,
# so client gets the proper message_stop event still)
return ExtraItem(
value=f"event: error\ndata: {error_chunk}\n\n".encode()
@@ -440,12 +479,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
async def handle_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
response = InstrumentedAnthropicStreamingResposne(
response = InstrumentedAnthropicStreamingResponse(
context=context,
client=client,
anthropic_request=anthropic_request,
+62 -24
View File
@@ -14,9 +14,14 @@ from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from common.request_context_data import RequestContextData
from common.guardrails import GuardrailAction
from common.request_context import RequestContext
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -76,10 +81,17 @@ async def gemini_generate_content_gateway(
headers=headers,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
dataset_guardrails=dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
@@ -98,16 +110,18 @@ async def gemini_generate_content_gateway(
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
"""Instrumented streaming response for Gemini API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
@@ -124,6 +138,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
location: Literal["request", "response"],
guardrails_execution_result: dict[str, Any],
) -> dict:
"""Create a refusal response for the given request or response"""
return {
"candidates": [
{
@@ -157,10 +172,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -184,6 +202,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
)
async def event_generator(self):
"""Event generator for streaming responses"""
response = await self.client.send(self.gemini_request, stream=True)
if response.status_code != 200:
@@ -199,6 +218,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
yield chunk
async def on_chunk(self, chunk):
"""Processes each chunk of the streaming response"""
chunk_text = chunk.decode().strip()
if not chunk_text:
return
@@ -210,12 +230,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
if (
self.merged_response.get("candidates", [])
and self.merged_response.get("candidates")[0].get("finishReason", "")
and self.context.config
and self.context.config.guardrails
and self.context.dataset_guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -254,7 +275,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
async def stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
@@ -332,7 +353,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) ->
def create_metadata(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {
@@ -352,23 +373,32 @@ def create_metadata(
async def get_guardrails_check_result(
context: RequestContextData, response_json: dict[str, Any]
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
# Determine which guardrails to apply based on the action
guardrails = (
context.dataset_guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.dataset_guardrails.blocking_guardrails
)
if not guardrails:
return {}
converted_requests = convert_request(context.request_json)
converted_responses = convert_response(response_json)
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=converted_requests + converted_responses,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
response_json: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
@@ -391,16 +421,18 @@ async def push_to_explorer(
class InstrumentedGeminiResponse(InstrumentedResponse):
"""Instrumented response for Gemini API"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
@@ -412,10 +444,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
self.guardrails_execution_result: Optional[dict[str, Any]] = None
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -463,6 +498,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
)
async def request(self):
"""Makes the request to the Gemini API and return the response"""
self.response = await self.client.send(self.gemini_request)
response_string = self.response.text
@@ -492,10 +528,12 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
if self.context.config and self.context.config.guardrails:
if self.context.dataset_guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.response_json
self.context,
action=GuardrailAction.BLOCK,
response_json=self.response_json,
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
@@ -539,7 +577,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
async def handle_non_streaming_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
+66 -35
View File
@@ -14,8 +14,13 @@ from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from common.request_context_data import RequestContextData
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from common.guardrails import GuardrailAction
from common.request_context import RequestContext
from integrations.explorer import (
create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer,
push_trace,
)
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
@@ -72,10 +77,17 @@ async def openai_chat_completions_gateway(
headers=headers,
)
context = RequestContextData(
dataset_guardrails = None
if dataset_name:
# Get the guardrails for the dataset
dataset_guardrails = await fetch_guardrails_from_explorer(
dataset_name, invariant_authorization
)
context = RequestContext.create(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
dataset_guardrails=dataset_guardrails,
config=config,
)
asyncio.create_task(preload_guardrails(context))
@@ -92,19 +104,20 @@ async def openai_chat_completions_gateway(
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
"""
Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
Does a streaming OpenAI completion request at the core, but also checks guardrails
before (concurrent) and after the request.
"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
@@ -131,10 +144,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
self.tool_call_mapping_by_index = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
"""
Check guardrails in a pipelined fashion, before processing the first chunk
(for input guardrailing).
"""
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -164,6 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
async def on_chunk(self, chunk):
"""Processes each chunk of the stream and checks guardrails at the end of the stream"""
# process and check each chunk
chunk_text = chunk.decode().strip()
if not chunk_text:
@@ -179,14 +198,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
if (
"data: [DONE]" in chunk_text
and self.context.config
and self.context.config.guardrails
):
if "data: [DONE]" in chunk_text and self.context.dataset_guardrails:
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -214,10 +231,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
)
async def event_generator(self):
"""
Actual OpenAI stream response.
"""
"""Actual OpenAI stream response."""
response = await self.client.send(self.open_ai_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
@@ -234,7 +248,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
async def handle_stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
@@ -389,7 +403,7 @@ def update_existing_choice_with_delta(
def create_metadata(
context: RequestContextData, merged_response: dict[str, Any]
context: RequestContext, merged_response: dict[str, Any]
) -> dict[str, Any]:
"""Creates metadata for the trace"""
metadata = {
@@ -409,7 +423,7 @@ def create_metadata(
async def push_to_explorer(
context: RequestContextData,
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
@@ -437,18 +451,28 @@ async def push_to_explorer(
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any] | None = None
context: RequestContext,
action: GuardrailAction,
json_response: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = list(context.request_json.get("messages", []))
# Determine which guardrails to apply based on the action
guardrails = (
context.dataset_guardrails.logging_guardrails
if action == GuardrailAction.LOG
else context.dataset_guardrails.blocking_guardrails
)
if not guardrails:
return {}
messages = list(context.request_json.get("messages", []))
if json_response is not None:
messages += [choice["message"] for choice in json_response.get("choices", [])]
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=context.config.guardrails,
guardrails=guardrails,
invariant_authorization=context.invariant_authorization,
)
return guardrails_execution_result
@@ -456,19 +480,20 @@ async def get_guardrails_check_result(
class InstrumentedOpenAIResponse(InstrumentedResponse):
"""
Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
Does an OpenAI completion request at the core, but also checks guardrails
before (concurrent) and after the request.
"""
def __init__(
self,
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.context: RequestContext = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
@@ -480,11 +505,14 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.guardrails_execution_result: Optional[dict] = None
async def on_start(self):
"""Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
if self.context.config and self.context.config.guardrails:
"""
Checks guardrails in a pipelined fashion, before processing
the first chunk (for input guardrailing)
"""
if self.context.dataset_guardrails:
# block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context
self.context, action=GuardrailAction.BLOCK
)
if self.guardrails_execution_result.get("errors", []):
# Push annotated trace to the explorer - don't block on its response
@@ -542,7 +570,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
async def on_end(self):
"""Postprocesses the OpenAI response and potentially replace it with a guardrails error."""
# these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
# these two request outputs are guaranteed to be available by the time we reach
# this point (after self.request() was executed)
# nevertheless, we check for them to avoid any potential issues
assert (
self.response is not None
@@ -555,10 +584,12 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
response_code = self.response.status_code
# if we have guardrails, check the response
if self.context.config and self.context.config.guardrails:
if self.context.dataset_guardrails:
# run guardrails again, this time on request + response
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
self.context,
action=GuardrailAction.BLOCK,
json_response=self.json_response,
)
if self.guardrails_execution_result.get("errors", []):
response_string = json.dumps(
@@ -601,7 +632,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
async def handle_non_stream_response(
context: RequestContextData,
context: RequestContext,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
+5
View File
@@ -93,6 +93,11 @@ integration_tests() {
fi
echo "File successfully downloaded: $FILE"
if [[ -z "$INVARIANT_API_KEY" ]]; then
echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests."
exit 1
fi
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py"
if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then
if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then
@@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{
"ANTHROPIC_API_KEY": anthropic_api_key
+ ";invariant-auth=<not needed for test>"
},
{"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix},
):
client = anthropic.Anthropic(
http_client=Client(),
@@ -26,7 +26,7 @@ class WeatherAgent:
def __init__(self, gateway_url, push_to_explorer):
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
self.client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
@@ -26,7 +26,7 @@ async def test_response_without_tool_call(
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
client = anthropic.Anthropic(
http_client=Client(
@@ -91,7 +91,7 @@ async def test_streaming_response_without_tool_call(
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
client = anthropic.Anthropic(
http_client=Client(
@@ -151,7 +151,7 @@ async def test_generate_content_with_tool_call(
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
},
)
@@ -36,7 +36,7 @@ async def test_generate_content(
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
},
)
@@ -123,7 +123,7 @@ async def test_generate_content_with_image(
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
"headers": {
"invariant-authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
},
)
@@ -181,9 +181,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
gemini_api_key = os.getenv("GEMINI_API_KEY")
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{"GEMINI_API_KEY": gemini_api_key + ";invariant-auth=<not needed for test>"},
{"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix},
):
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
@@ -32,7 +32,7 @@ async def test_chat_completion_with_tool_call_without_streaming(
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
@@ -150,7 +150,7 @@ async def test_chat_completion_with_tool_call_with_streaming(
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
@@ -34,7 +34,7 @@ async def test_chat_completion(
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
@@ -107,7 +107,7 @@ async def test_chat_completion_with_image(
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
@@ -189,9 +189,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
openai_api_key = os.getenv("OPENAI_API_KEY")
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
with patch.dict(
os.environ,
{"OPENAI_API_KEY": openai_api_key + ";invariant-auth=<not needed for test>"},
{"OPENAI_API_KEY": openai_api_key + invariant_key_suffix},
):
client = OpenAI(
http_client=Client(),
@@ -252,7 +253,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",