Pipelined Guardrails (#32)

* initial draft: pipelined guardrails

* documentation on stream instrumentation

* more comments

* fix: return earlier

* non-streaming case

* handle non-streaming case

* fix more cases

* simplify request instrumentation

* improve comments

* fix import issues

* extend tests for input guardrailing

* anthropic integration of pipelined and pre-guardrailing

* fix gemini streamed refusal
Author: Luca Beurer-Kellner (committed by GitHub)
Date: 2025-03-31 14:13:58 +02:00
parent 4671c8b67e
commit 7c0bb957fb
13 changed files with 1659 additions and 495 deletions
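At its core, this change overlaps guardrail evaluation with the upstream LLM call instead of running them sequentially: the input check starts concurrently with the provider request, and the stream is only short-circuited if the check fails, so a passing check adds close to zero latency. A minimal sketch of the idea, with illustrative names rather than the gateway's actual API:

import asyncio

async def pipelined_request(check_input, send_upstream):
    # start the guardrail check and the provider request at the same time
    check_task = asyncio.create_task(check_input())
    request_task = asyncio.create_task(send_upstream())
    verdict = await check_task
    if verdict.get("errors"):
        # refuse before any output reaches the client
        request_task.cancel()
        return {"error": "[Invariant] The request did not pass the guardrails", "details": verdict}
    return await request_task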

.env (2 lines changed)

@@ -3,4 +3,4 @@
# If you want to push to a local instance of explorer, then specify the app-api docker container name like:
# http://<app-api-docker-container-name>:8000 to push to the local explorer instance.
INVARIANT_API_URL=https://explorer.invariantlabs.ai
GUADRAILS_API_URL=https://guardrail.invariantnet.com
GUADRAILS_API_URL=https://explorer.invariantlabs.ai

example_policy.gr (new file, 7 lines)

@@ -0,0 +1,7 @@
from invariant.detectors import prompt_injection
raise "Don't say 'Hello'" if:
(msg: Message)
msg.role == "user"
prompt_injection(msg.content)
# "Hello" in msg.content

common/config_manager.py

@@ -4,8 +4,6 @@ import asyncio
import os
import threading
from integrations.guardails import _preload
from httpx import HTTPStatusError
@@ -20,6 +18,8 @@ class GatewayConfig:
Loads the guardrails from the file specified in GUARDRAILS_FILE_PATH.
Returns the guardrails file content as a string.
"""
from integrations.guardrails import _preload
guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "")
if not guardrails_file:
print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]")

integrations/guardails.py (deleted)

@@ -1,132 +0,0 @@
"""Utility functions for Guardrails execution."""
import asyncio
import os
import time
from typing import Any, Dict, List
from functools import wraps
import httpx
DEFAULT_API_URL = "https://guardrail.invariantnet.com"
# Timestamps of last API calls per guardrails string
_guardrails_cache = {}
# Locks per guardrails string
_guardrails_locks = {}
def rate_limit(expiration_time: int = 3600):
"""
Decorator to limit API calls to once per expiration_time seconds
per unique guardrails string.
Args:
expiration_time (int): Time in seconds to cache the guardrails.
"""
def decorator(func):
@wraps(func)
async def wrapper(guardrails: str, *args, **kwargs):
now = time.time()
# Get or create a per-guardrail lock
if guardrails not in _guardrails_locks:
_guardrails_locks[guardrails] = asyncio.Lock()
guardrail_lock = _guardrails_locks[guardrails]
async with guardrail_lock:
last_called = _guardrails_cache.get(guardrails)
if last_called and (now - last_called < expiration_time):
# Skipping API call: Guardrails '{guardrails}' already
# preloaded within expiration_time
return
# Update cache timestamp
_guardrails_cache[guardrails] = now
try:
await func(guardrails, *args, **kwargs)
finally:
_guardrails_locks.pop(guardrails, None)
return wrapper
return decorator
@rate_limit(3600) # Don't preload the same guardrails string more than once per hour
async def _preload(guardrails: str, invariant_authorization: str) -> None:
"""
Calls the Guardrails API to preload the provided policy for faster checking later.
Args:
guardrails (str): The guardrails to preload.
invariant_authorization (str): Value of the
invariant-authorization header.
"""
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
result = await client.post(
f"{url}/api/v1/policy/load",
json={"policy": guardrails},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
},
)
result.raise_for_status()
async def preload_guardrails(context: "RequestContextData") -> None:
"""
Preloads the guardrails for faster checking later.
Args:
context: RequestContextData object.
"""
if not context.config or not context.config.guardrails:
return
try:
task = asyncio.create_task(
_preload(context.config.guardrails, context.invariant_authorization)
)
asyncio.shield(task)
except Exception as e:
print(f"Error scheduling preload_guardrails task: {e}")
async def check_guardrails(
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (str): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
Returns:
Dict: Response containing guardrail check results.
"""
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
result = await client.post(
f"{url}/api/v1/policy/check",
json={"messages": messages, "policy": guardrails},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
},
)
print(f"Guardrail check response: {result.json()}")
return result.json()
except Exception as e:
print(f"Failed to verify guardrails: {e}")
return {"error": str(e)}

integrations/guardrails.py (new file)

@@ -0,0 +1,358 @@
"""Utility functions for Guardrails execution."""
import asyncio
import os
import time
from typing import Any, Dict, List
from functools import wraps
import httpx
from common.request_context_data import RequestContextData
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
# Timestamps of last API calls per guardrails string
_guardrails_cache = {}
# Locks per guardrails string
_guardrails_locks = {}
def rate_limit(expiration_time: int = 3600):
"""
Decorator to limit API calls to once per expiration_time seconds
per unique guardrails string.
Args:
expiration_time (int): Time in seconds to cache the guardrails.
"""
def decorator(func):
@wraps(func)
async def wrapper(guardrails: str, *args, **kwargs):
now = time.time()
# Get or create a per-guardrail lock
if guardrails not in _guardrails_locks:
_guardrails_locks[guardrails] = asyncio.Lock()
guardrail_lock = _guardrails_locks[guardrails]
async with guardrail_lock:
last_called = _guardrails_cache.get(guardrails)
if last_called and (now - last_called < expiration_time):
# Skipping API call: Guardrails '{guardrails}' already
# preloaded within expiration_time
return
# Update cache timestamp
_guardrails_cache[guardrails] = now
try:
await func(guardrails, *args, **kwargs)
finally:
_guardrails_locks.pop(guardrails, None)
return wrapper
return decorator
@rate_limit(3600) # Don't preload the same guardrails string more than once per hour
async def _preload(guardrails: str, invariant_authorization: str) -> None:
"""
Calls the Guardrails API to preload the provided policy for faster checking later.
Args:
guardrails (str): The guardrails to preload.
invariant_authorization (str): Value of the
invariant-authorization header.
"""
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
result = await client.post(
f"{url}/api/v1/policy/load",
json={"policy": guardrails},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
},
)
result.raise_for_status()
async def preload_guardrails(context: "RequestContextData") -> None:
"""
Preloads the guardrails for faster checking later.
Args:
context: RequestContextData object.
"""
if not context.config or not context.config.guardrails:
return
try:
task = asyncio.create_task(
_preload(context.config.guardrails, context.invariant_authorization)
)
asyncio.shield(task)
except Exception as e:
print(f"Error scheduling preload_guardrails task: {e}")
class ExtraItem:
"""
Return an instance of this class from an instrumented stream callback to yield an extra item in the resulting stream.
"""
def __init__(self, value, end_of_stream=False):
self.value = value
self.end_of_stream = end_of_stream
def __str__(self):
return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"
class Replacement(ExtraItem):
"""
Like ExtraItem, but used to replace the full request result in the case of an 'InstrumentedResponse'.
"""
def __init__(self, value):
super().__init__(value, end_of_stream=True)
def __str__(self):
return f"<Replacement value={self.value}>"
class InstrumentedStreamingResponse:
def __init__(self):
# request statistics
self.stat_token_times = []
self.stat_before_time = None
self.stat_after_time = None
self.stat_first_item_time = None
async def on_chunk(self, chunk: Any) -> ExtraItem | None:
"""
Called on every chunk (async); may return an ExtraItem to inject into the stream.
"""
pass
async def on_start(self) -> ExtraItem | None:
"""
Hook invoked once before streaming begins; may return an ExtraItem (or Replacement) to short-circuit the stream.
"""
pass
async def on_end(self) -> ExtraItem | None:
"""
Hook invoked after the stream ends; may return an ExtraItem to append to the stream.
"""
pass
async def event_generator(self):
"""
Produces the underlying stream items. Subclasses must implement this.
Yields:
The streamed data.
"""
raise NotImplementedError("This method should be implemented in a subclass.")
async def instrumented_event_generator(self):
"""
Streams the items from event_generator() and invokes all instrumented hooks.
Yields:
The streamed data, plus any extra items returned by the hooks.
"""
try:
start = time.time()
# schedule on_start which can be run concurrently
start_task = asyncio.create_task(self.on_start(), name="instrumentor:start")
# create async iterator from async_iterable
aiterable = aiter(self.event_generator())
# [STAT] capture start time of first item
start_first_item_request = time.time()
# waits for first item of the iterable
async def wait_for_first_item():
nonlocal start_first_item_request, aiterable
r = await aiterable.__anext__()
if self.stat_first_item_time is None:
# [STAT] capture time to first item
self.stat_first_item_time = time.time() - start_first_item_request
return r
next_item_task = asyncio.create_task(
wait_for_first_item(), name="instrumentor:next:first"
)
# check if 'start_task' yields an extra item
if extra_item := await start_task:
# yield extra value before any real items
yield extra_item.value
# stop the stream if end_of_stream is True
if extra_item.end_of_stream:
# if the first item has not arrived yet
if not next_item_task.done():
# cancel the task
next_item_task.cancel()
# [STAT] set time-to-first-item to now + 0.01 (stream cut before the first item)
if self.stat_first_item_time is None:
self.stat_first_item_time = (
time.time() - start_first_item_request
) + 0.01
# don't wait for the first item if end_of_stream is True
return
# [STAT] capture before time stamp
self.stat_before_time = time.time() - start
while True:
# wait for first item
try:
item = await next_item_task
except StopAsyncIteration:
break
# schedule next item
next_item_task = asyncio.create_task(
aiterable.__anext__(), name="instrumentor:next"
)
# [STAT] capture token time stamp
if len(self.stat_token_times) == 0:
self.stat_token_times.append(time.time() - start)
else:
self.stat_token_times.append(
time.time() - start - sum(self.stat_token_times)
)
if extra_item := await self.on_chunk(item):
yield extra_item.value
# if end_of_stream is True, stop the stream
if extra_item.end_of_stream:
# cancel next task
next_item_task.cancel()
return
# yield item
yield item
# run on_end, before closing the stream (may yield an extra value)
if extra_item := await self.on_end():
# yield extra value before any real items
yield extra_item.value
# we ignore end_of_stream here, because we are already at the end
# [STAT] capture after time stamp
self.stat_after_time = time.time() - start
finally:
# [STAT] end all open intervals if not already closed
if self.stat_before_time is None:
self.stat_before_time = time.time() - start
if self.stat_after_time is None:
self.stat_after_time = 0
if self.stat_first_item_time is None:
self.stat_first_item_time = 0
# print statistics
token_times_5_decimals = str([f"{x:.5f}" for x in self.stat_token_times])
print(
f"[STATS]\n [token times: {token_times_5_decimals} ({len(self.stat_token_times)})]"
)
print(f" [before: {self.stat_before_time:.2f}s] ")
print(f" [time-to-first-item: {self.stat_first_item_time:.2f}s]")
print(
f" [zero-latency: {' TRUE' if self.stat_before_time < self.stat_first_item_time else 'FALSE'}]"
)
print(
f" [extra-latency: {self.stat_before_time - self.stat_first_item_time:.2f}s]"
)
print(f" [after: {self.stat_after_time:.2f}s]")
if len(self.stat_token_times) > 0:
print(
f" [average token time: {sum(self.stat_token_times) / len(self.stat_token_times):.2f}s]"
)
print(f" [total: {time.time() - start:.2f}s]")
class InstrumentedResponse(InstrumentedStreamingResponse):
"""
A class to instrument an async request with hooks for concurrent
pre-processing and post-processing (input and output guardrailing).
"""
async def event_generator(self):
"""
We implement the 'event_generator' as a single item stream,
where the item is the full result of the request.
"""
yield await self.request()
async def request(self):
"""
This method should be implemented in a subclass to perform the actual request.
"""
raise NotImplementedError("This method should be implemented in a subclass.")
async def instrumented_request(self):
"""
Returns the 'Response' object of the request, after applying all instrumented hooks.
"""
results = [r async for r in self.instrumented_event_generator()]
assert len(results) >= 1, "InstrumentedResponse must yield at least one item"
# we return the last item: if the end callback yields an extra item, that
# replaces the actual result, e.g. for output guardrailing.
return results[-1]
async def check_guardrails(
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (str): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
Returns:
Dict: Response containing guardrail check results.
"""
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
result = await client.post(
f"{url}/api/v1/policy/check",
json={"messages": messages, "policy": guardrails},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
},
)
if not result.is_success:
raise Exception(
f"Guardrails check failed: {result.status_code} - {result.text}"
)
print(f"Guardrail check response: {result.json()}")
return result.json()
except Exception as e:
print(f"Failed to verify guardrails: {e}")
return {"error": str(e)}


@@ -5,6 +5,7 @@ import json
from typing import Any, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
@@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import (
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from integrations.guardails import check_guardrails, preload_guardrails
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
Replacement,
check_guardrails,
preload_guardrails,
)
gateway = APIRouter()
@@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway(
if request_json.get("stream"):
return await handle_streaming_response(context, client, anthropic_request)
response = await client.send(anthropic_request)
return await handle_non_streaming_response(context, response)
return await handle_non_streaming_response(context, client, anthropic_request)
def create_metadata(
@@ -110,7 +117,8 @@ def combine_request_and_response_messages(
{"role": "system", "content": context.request_json.get("system")}
)
messages.extend(context.request_json.get("messages", []))
messages.append(json_response)
if len(json_response) > 0:
messages.append(json_response)
return messages
@@ -154,56 +162,282 @@ async def push_to_explorer(
)
class InstrumentedAnthropicResponse(InstrumentedResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
# response data
self.response: Optional[httpx.Response] = None
self.response_string: Optional[str] = None
self.json_response: Optional[dict[str, Any]] = None
# guardrailing response (if any)
self.guardrails_execution_result = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# if we find something, we prevent the request from going through
# and return an error instead
return Replacement(
Response(
content=error_chunk,
status_code=400,
media_type="application/json",
headers={"content-type": "application/json"},
)
)
async def request(self):
self.response = await self.client.send(self.anthropic_request)
try:
json_response = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}",
) from e
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
)
self.json_response = json_response
self.response_string = json.dumps(json_response)
return self._make_response(
content=self.response_string,
status_code=self.response.status_code,
)
def _make_response(self, content: str, status_code: int):
"""Creates a new Response object with the correct headers and content"""
assert self.response is not None, "response is None"
updated_headers = self.response.headers.copy()
updated_headers.pop("Content-Length", None)
return Response(
content=content,
status_code=status_code,
media_type="application/json",
headers=dict(updated_headers),
)
async def on_end(self):
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
# ensure the response data is available
assert self.response is not None, "response is None"
assert self.json_response is not None, "json_response is None"
assert self.response_string is not None, "response_string is None"
if self.context.config and self.context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
)
if guardrails_execution_result.get("errors", []):
guardrail_response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
)
# push to explorer (if configured)
if self.context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
guardrails_execution_result,
)
)
return Replacement(
self._make_response(
content=guardrail_response_string,
status_code=400,
)
)
# push to explorer (if configured)
if self.context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context, self.json_response, guardrails_execution_result
)
)
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
) -> Response:
"""Handles non-streaming Anthropic responses"""
try:
json_response = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
)
guardrails_execution_result = {}
response_string = json.dumps(json_response)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, json_response, guardrails_execution_result)
)
updated_headers = response.headers.copy()
updated_headers.pop("Content-Length", None)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(updated_headers),
response = InstrumentedAnthropicResponse(
context=context,
client=client,
anthropic_request=anthropic_request,
)
return await response.instrumented_request()
class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.anthropic_request: httpx.Request = anthropic_request
# response data
self.merged_response = {}
# guardrailing response (if any)
self.guardrails_execution_result = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"event: error\ndata: {error_chunk}\n\n".encode(),
end_of_stream=True,
)
async def event_generator(self):
"""Actual streaming response generator"""
response = await self.client.send(self.anthropic_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content)
error_detail = error_json.get("error", "Unknown error from Anthropic")
except json.JSONDecodeError:
error_detail = {
"error": "Failed to decode error response from Anthropic"
}
raise HTTPException(status_code=response.status_code, detail=error_detail)
# iterate over the response stream
async for chunk in response.aiter_bytes():
yield chunk
async def on_chunk(self, chunk):
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
return
# process chunk and extend the merged_response
process_chunk(decoded_chunk, self.merged_response)
# on last stream chunk, run output guardrails
if (
"event: message_stop" in decoded_chunk
and self.context.config
and self.context.config.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"type": "error",
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": self.guardrails_execution_result,
},
}
)
# yield an extra error chunk (without preventing the original chunk to go through after,
# so client gets the proper message_stop event still)
return ExtraItem(
value=f"event: error\ndata: {error_chunk}\n\n".encode()
)
async def on_end(self):
"""on_end: send full merged response to the exploree (if configured)"""
# don't block on the response from explorer (.create_task)
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
async def handle_streaming_response(
context: RequestContextData,
@@ -211,63 +445,15 @@ async def handle_streaming_response(
anthropic_request: httpx.Request,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
merged_response = {}
response = InstrumentedAnthropicStreamingResponse(
context=context,
client=client,
anthropic_request=anthropic_request,
)
response = await client.send(anthropic_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content)
error_detail = error_json.get("error", "Unknown error from Anthropic")
except json.JSONDecodeError:
error_detail = {"error": "Failed to decode error response from Anthropic"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
async def event_generator() -> Any:
async for chunk in response.aiter_bytes():
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
continue
process_chunk(decoded_chunk, merged_response)
if (
"event: message_stop" in decoded_chunk
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"type": "error",
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
},
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"event: error\ndata: {error_chunk}\n\n".encode()
return
yield chunk
if context.dataset_name:
# Push to Explorer - don't block on the response
asyncio.create_task(push_to_explorer(context, merged_response))
generator = event_generator()
return StreamingResponse(generator, media_type="text/event-stream")
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
)
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:

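On the wire, the refusals above arrive as ordinary SSE frames: an 'event: error' line followed by a 'data:' line carrying the guardrail details. A hedged client-side sketch of spotting them in a raw stream, assuming httpx's streaming API:

import json
import httpx

async def detect_refusal(url: str, payload: dict, headers: dict) -> None:
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, headers=headers) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("event: error"):
                    print("guardrail refusal incoming")
                elif line.startswith("data: ") and '"error"' in line:
                    print(json.loads(line[len("data: "):])["error"]["message"])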

@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any, Optional
from typing import Any, Literal, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -15,8 +15,16 @@ from common.constants import (
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
Replacement,
preload_guardrails,
check_guardrails,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardails import check_guardrails, preload_guardrails
from integrations.guardrails import check_guardrails, preload_guardrails
gateway = APIRouter()
@@ -82,13 +90,169 @@ async def gemini_generate_content_gateway(
client,
gemini_request,
)
response = await client.send(gemini_request)
return await handle_non_streaming_response(
context,
response,
client,
gemini_request,
)
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
# Store the progressively merged response
self.merged_response = {
"candidates": [{"content": {"parts": []}, "finishReason": None}]
}
# guardrailing execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
def make_refusal(
self,
location: Literal["request", "response"],
guardrails_execution_result: dict[str, Any],
) -> dict:
return {
"candidates": [
{
"content": {
"parts": [
{
"text": f"[Invariant] The {location} did not pass the guardrails",
}
],
}
}
],
"error": {
"code": 400,
"message": f"[Invariant] The {location} did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
"promptFeedback": {
"blockReason": "SAFETY",
"block_reason_message": f"[Invariant] The {location} did not pass the guardrails: "
+ json.dumps(guardrails_execution_result),
"safetyRatings": [
{
"category": "HARM_CATEGORY_UNSPECIFIED",
"probability": "HIGH",
"blocked": True,
}
],
},
}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
self.make_refusal("request", self.guardrails_execution_result)
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True
)
async def event_generator(self):
response = await self.client.send(self.gemini_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from Gemini API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse Gemini error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
async for chunk in response.aiter_bytes():
yield chunk
async def on_chunk(self, chunk):
chunk_text = chunk.decode().strip()
if not chunk_text:
return
# Parse and update merged_response incrementally
process_chunk_text(self.merged_response, chunk_text)
# runs on the last stream item
if (
self.merged_response.get("candidates", [])
and self.merged_response.get("candidates")[0].get("finishReason", "")
and self.context.config
and self.context.config.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
self.make_refusal("response", self.guardrails_execution_result)
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
return ExtraItem(
value=f"data: {error_chunk}\r\n\r\n".encode(),
# for Gemini we have to end the stream prematurely, as the client SDK
# will not stop streaming when it encounters an error
end_of_stream=True,
)
async def on_end(self):
"""Runs when the stream ends."""
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
@@ -96,76 +260,21 @@ async def stream_response(
) -> Response:
"""Handles streaming the Gemini response to the client"""
response = await client.send(gemini_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from Gemini API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse Gemini error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
response = InstrumentedStreamingGeminiResponse(
context=context,
client=client,
gemini_request=gemini_request,
)
async def event_generator() -> Any:
# Store the progressively merged response
merged_response = {
"candidates": [{"content": {"parts": []}, "finishReason": None}]
}
async for chunk in response.aiter_bytes():
chunk_text = chunk.decode().strip()
if not chunk_text:
continue
# Parse and update merged_response incrementally
process_chunk_text(merged_response, chunk_text)
if (
merged_response.get("candidates", [])
and merged_response.get("candidates")[0].get("finishReason", "")
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"data: {error_chunk}\n\n".encode()
return
# Yield chunk immediately to the client
async def event_generator():
async for chunk in response.instrumented_event_generator():
yield chunk
print("chunk", chunk)
if context.dataset_name:
# Push to Explorer - don't block on the response
asyncio.create_task(
push_to_explorer(
context,
merged_response,
)
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)
def process_chunk_text(
@@ -281,53 +390,165 @@ async def push_to_explorer(
)
class InstrumentedGeminiResponse(InstrumentedResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
# response data
self.response: Optional[httpx.Response] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrails execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
"prompt_feedback": {
"blockReason": "SAFETY",
"safetyRatings": [
{
"category": "HARM_CATEGORY_UNSPECIFIED",
"probability": 0.0,
"blocked": True,
}
],
},
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# if we find something, we prevent the request from going through
# and return an error response instead
return Replacement(
Response(
content=error_chunk,
status_code=400,
media_type="application/json",
headers={
"Content-Type": "application/json",
},
)
)
async def request(self):
self.response = await self.client.send(self.gemini_request)
response_string = self.response.text
response_code = self.response.status_code
try:
self.response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail="Invalid JSON response received from Gemini API",
) from e
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.response_json.get("error", "Unknown error from Gemini API"),
)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
async def on_end(self):
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
if self.context.config and self.context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.response_json
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
response_code = 400
if self.context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context,
self.response_json,
guardrails_execution_result,
)
)
return Replacement(
Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
)
# Otherwise, also push to Explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context, self.response_json, guardrails_execution_result
)
)
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
"""Handles non-streaming Gemini responses"""
try:
response_json = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail="Invalid JSON response received from Gemini API",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=response_json.get("error", "Unknown error from Gemini API"),
)
guardrails_execution_result = {}
response_string = json.dumps(response_json)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, response_json
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, response_json, guardrails_execution_result)
)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(response.headers),
response = InstrumentedGeminiResponse(
context=context,
client=client,
gemini_request=gemini_request,
)
return await response.instrumented_request()
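Because the Gemini client SDK will not stop streaming on an error object alone, make_refusal above encodes the violation redundantly: an error object with status INVARIANT_GUARDRAILS_VIOLATION, a SAFETY promptFeedback block, and a plain-text part. A small client-side sketch of detecting it (field names as in make_refusal):

def is_invariant_refusal(chunk: dict) -> bool:
    # matches both the streamed refusal chunk and the non-streaming 400 body
    return chunk.get("error", {}).get("status") == "INVARIANT_GUARDRAILS_VIOLATION"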


@@ -13,7 +13,13 @@ from common.constants import (
IGNORED_HEADERS,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardails import check_guardrails, preload_guardrails
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
check_guardrails,
preload_guardrails,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
@@ -74,19 +80,159 @@ async def openai_chat_completions_gateway(
asyncio.create_task(preload_guardrails(context))
if request_json.get("stream", False):
return await stream_response(
return await handle_stream_response(
context,
client,
open_ai_request,
)
response = await client.send(open_ai_request)
return await handle_non_streaming_response(
context,
response,
)
return await handle_non_stream_response(context, client, open_ai_request)
async def stream_response(
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
"""
Performs a streaming OpenAI completion request, checking guardrails before (concurrently with) and after the request.
"""
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
self.merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
# Each chunk in the stream contains a list called "choices"; each entry
# in the list has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
self.choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
self.tool_call_mapping_by_index = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"data: {error_chunk}\n\n".encode(),
end_of_stream=True,
)
async def on_chunk(self, chunk):
# process and check each chunk
chunk_text = chunk.decode().strip()
if not chunk_text:
return
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
self.merged_response,
self.choice_mapping_by_index,
self.tool_call_mapping_by_index,
)
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
if (
"data: [DONE]" in chunk_text
and self.context.config
and self.context.config.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": self.guardrails_execution_result,
}
}
)
# yield an extra error chunk (without preventing the original chunk to go through after)
return ExtraItem(f"data: {error_chunk}\n\n".encode())
# push will happen in on_end
async def on_end(self):
"""Sends full merged response to the exploree."""
# don't block on the response from explorer (.create_task)
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context, self.merged_response, self.guardrails_execution_result
)
)
async def event_generator(self):
"""
Actual OpenAI stream response.
"""
response = await self.client.send(self.open_ai_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from OpenAI API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse OpenAI error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
# stream out chunks
async for chunk in response.aiter_bytes():
yield chunk
async def handle_stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
@@ -98,89 +244,15 @@ async def stream_response(
It is sent to the Invariant Explorer at the end of the stream
"""
response = await client.send(open_ai_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from OpenAI API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse OpenAI error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
response = InstrumentedOpenAIStreamResponse(
context,
client,
open_ai_request,
)
async def event_generator() -> Any:
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
merged_response = {
"id": None,
"object": "chat.completion",
"created": None,
"model": None,
"choices": [],
"usage": None,
}
# Each chunk in the stream contains a list called "choices" each entry in the list
# has an index.
# A choice has a field called "delta" which may contain a list called "tool_calls".
# Maps the choice index in the stream to the index in the merged_response["choices"] list
choice_mapping_by_index = {}
# Combines the choice index and tool call index to uniquely identify a tool call
tool_call_mapping_by_index = {}
async for chunk in response.aiter_bytes():
chunk_text = chunk.decode().strip()
if not chunk_text:
continue
# Process the chunk
# This will update merged_response with the data from the chunk
process_chunk_text(
chunk_text,
merged_response,
choice_mapping_by_index,
tool_call_mapping_by_index,
)
# Check guardrails on the last chunk.
if (
"data: [DONE]" in chunk_text
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"data: {error_chunk}\n\n".encode()
return
# Yield chunk to the client
yield chunk
# Send full merged response to the explorer
# Don't block on the response from explorer
if context.dataset_name:
asyncio.create_task(push_to_explorer(context, merged_response))
return StreamingResponse(event_generator(), media_type="text/event-stream")
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
)
def initialize_merged_response() -> dict[str, Any]:
@@ -329,7 +401,7 @@ def create_metadata(
{
key: value
for key, value in merged_response.items()
if key in ("usage", "model")
if key in ("usage", "model") and merged_response.get(key) is not None
}
)
return metadata
@@ -364,11 +436,13 @@ async def push_to_explorer(
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any]
context: RequestContextData, json_response: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in json_response.get("choices", [])]
if json_response is not None:
messages += [choice["message"] for choice in json_response.get("choices", [])]
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
@@ -379,49 +453,165 @@ async def get_guardrails_check_result(
return guardrails_execution_result
async def handle_non_streaming_response(
context: RequestContextData, response: httpx.Response
class InstrumentedOpenAIResponse(InstrumentedResponse):
"""
Performs an OpenAI completion request, checking guardrails before (concurrently with) and after the request.
"""
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
):
super().__init__()
# request parameters
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.open_ai_request: httpx.Request = open_ai_request
# request outputs
self.response: Optional[httpx.Response] = None
self.json_response: Optional[dict[str, Any]] = None
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
async def on_start(self):
"""Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
if self.context.config and self.context.config.guardrails:
# block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context
)
if self.guardrails_execution_result.get("errors", []):
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# replace the response with the error message
return ExtraItem(
Response(
content=json.dumps(
{
"error": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
),
status_code=400,
media_type="application/json",
),
end_of_stream=True,
)
async def request(self):
"""Actual OpenAI request."""
self.response = await self.client.send(self.open_ai_request)
try:
self.json_response = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail="Invalid JSON response received from OpenAI API",
) from e
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
)
response_string = json.dumps(self.json_response)
response_code = self.response.status_code
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
async def on_end(self):
"""Postprocesses the OpenAI response and potentially replace it with a guardrails error."""
# these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
# nevertheless, we check for them to avoid any potential issues
assert (
self.response is not None
), "on_end called before 'self.response' was available"
assert (
self.json_response is not None
), "on_end called before 'self.json_response' was available"
# extract original response status code
response_code = self.response.status_code
# if we have guardrails, check the response
if self.context.config and self.context.config.guardrails:
# run guardrails again, this time on request + response
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.json_response
)
if self.guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": self.guardrails_execution_result,
}
)
response_code = 400
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.guardrails_execution_result,
)
)
# replace the response with the error message
return ExtraItem(
Response(
content=response_string,
status_code=response_code,
media_type="application/json",
),
)
# Push annotated trace to the explorer in any case - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
# include any guardrailing errors if available
self.guardrails_execution_result,
)
)
async def handle_non_stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
) -> Response:
"""Handles non-streaming OpenAI responses"""
try:
json_response = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail="Invalid JSON response received from OpenAI API",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
guardrails_execution_result = {}
response_string = json.dumps(json_response)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, json_response, guardrails_execution_result)
)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(response.headers),
response = InstrumentedOpenAIResponse(
context,
client,
open_ai_request,
)
return await response.instrumented_request()
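Client-side, the gateway acts as a drop-in base URL, mirroring the Anthropic and Gemini test setups below. A hedged sketch; the exact OpenAI route path is an assumption following the same pattern as the other providers:

from openai import OpenAI

client = OpenAI(
    api_key="<openai-api-key>",
    base_url="<gateway-url>/api/v1/gateway/<dataset>/openai",  # assumed route pattern
    default_headers={"Invariant-Authorization": "Bearer <invariant-api-key>"},
)
# requests now pass through input and output guardrails; a violation
# surfaces as an HTTP 400 whose body contains the guardrail details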


@@ -15,7 +15,7 @@ UVICORN_PORT=${PORT:-8000}
# using 'exec' below ensures that signals like SIGTERM are passed to the child process
# and not the shell script itself (important when running in a container)
if [ "$DEV_MODE" = "true" ]; then
exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload
exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload --reload-dir /srv/resources --reload-dir /srv/gateway
else
exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT
fi


@@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test input guardrail enforcement with Anthropic."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/anthropic",
)
request = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 100,
"messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
}
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
_ = client.messages.create(**request, stream=False)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
)
else:
with pytest.raises(APIStatusError) as exc_info:
chat_response = client.messages.create(**request, stream=True)
for _ in chat_response:
pass
assert (
"[Invariant] The request did not pass the guardrails"
in exc_info.value.message
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value.body
)
if push_to_explorer:
time.sleep(2)
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
# in case of input guardrailing, the pushed trace will not contain a response
trace = trace_response.json()
assert len(trace["messages"]) == 1, "Only the user message should be present"
assert trace["messages"][0] == {
"role": "user",
"content": "Tell me more about Fight Club.",
}
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"]
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)

View File

@@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file(
    else:
        response = client.models.generate_content_stream(**request)
        for chunk in response:
            assert "Dublin" not in str(chunk)
        assert_is_streamed_refusal(
            response,
            [
                "[Invariant] The response did not pass the guardrails",
                "Dublin detected in the response",
            ],
        )
    if push_to_explorer:
        # Wait for the trace to be saved
@@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file(
            **request,
        )
        for chunk in response:
            assert "Madrid" not in str(chunk)
        assert_is_streamed_refusal(
            response,
            [
                "[Invariant] The response did not pass the guardrails",
                "get_capital is called with Germany as argument",
            ],
        )
    if push_to_explorer:
        # Wait for the trace to be saved
@@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test input guardrail enforcement with Gemini."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
},
)
request = {
"model": "gemini-2.0-flash",
"contents": "Tell me more about Fight Club.",
"config": {
"maxOutputTokens": 200,
},
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
client.models.generate_content(**request)
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
)
else:
response = client.models.generate_content_stream(**request)
assert_is_streamed_refusal(
response,
[
"[Invariant] The request did not pass the guardrails",
"Users must not mention the magic phrase 'Fight Club'",
],
)
if push_to_explorer:
time.sleep(2)
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 1
assert trace["messages"][0] == {
"role": "user",
"content": [{"type": "text", "text": "Tell me more about Fight Club."}],
}
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"]
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
def is_refusal(chunk):
return (
len(chunk.candidates) == 1
and chunk.candidates[0].content.parts[0].text.startswith("[Invariant]")
and chunk.prompt_feedback is not None
and "BlockedReason.SAFETY" in str(chunk.prompt_feedback)
)
def assert_is_streamed_refusal(response, expected_message_components: list[str]):
"""
Validates that the streamed response contains a refusal at the end (or as only message).
"""
num_chunks = 0
for c in response:
num_chunks += 1
assert num_chunks >= 1, "Expected at least one chunk"
# last chunk must be a refusal
assert is_refusal(c)
for emc in expected_message_components:
assert (
emc in c.model_dump_json()
), f"Expected message component {emc} not found in refusal message: {c.model_dump_json()}"

View File

@@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test the message content guardrail."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
request = {
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
}
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
)
else:
with pytest.raises(APIError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert (
"[Invariant] The request did not pass the guardrails"
in exc_info.value.message
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value.body
)
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
# in case of input guardrailing, the pushed trace will not contain a response
assert len(trace["messages"]) == 1
assert trace["messages"][0] == {
"role": "user",
"content": "Tell me more about Fight Club.",
}
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"]
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)

View File

@@ -13,4 +13,10 @@ raise "Dublin detected in the response" if:
raise "get_capital is called with Germany as argument" if:
(call: ToolCall)
call is tool:get_capital
call.function.arguments["country_name"] == "Germany"
call.function.arguments["country_name"] == "Germany"
# For input guardrailing specifically
raise "Users must not mention the magic phrase 'Fight Club'" if:
(msg: Message)
msg.role == "user"
"Fight Club" in msg.content