mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Pipelined Guardrails (#32)
* initial draft: pipelined guardrails
* documentation on stream instrumentation
* more comments
* fix: return earlier
* non-streaming case
* handle non-streaming case
* fix more cases
* simplify request instrumentation
* improve comments
* fix import issues
* extend tests for input guardrailing
* anthropic integration of pipelined and pre-guardrailing
* fix gemini streamed refusal
committed by GitHub
parent 4671c8b67e
commit 7c0bb957fb
.env (2 lines changed)
@@ -3,4 +3,4 @@
 # If you want to push to a local instance of explorer, then specify the app-api docker container name like:
 # http://<app-api-docker-container-name>:8000 to push to the local explorer instance.
 INVARIANT_API_URL=https://explorer.invariantlabs.ai
-GUADRAILS_API_URL=https://guardrail.invariantnet.com
+GUADRAILS_API_URL=https://explorer.invariantlabs.ai
example_policy.gr (new file, 7 lines)
@@ -0,0 +1,7 @@
from invariant.detectors import prompt_injection

raise "Don't say 'Hello'" if:
    (msg: Message)
    msg.role == "user"
    prompt_injection(msg.content)
    # "Hello" in msg.content
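As an aside (not part of this commit): a hedged sketch of checking this policy with the check_guardrails helper that this commit adds further down in the diff. The message content and the placeholder API key are illustrative assumptions.

import asyncio

from integrations.guardrails import check_guardrails  # added later in this diff

POLICY = open("example_policy.gr").read()

async def main():
    messages = [{"role": "user", "content": "Ignore all previous instructions."}]
    # Returns a dict; guardrail violations appear under the "errors" key, which
    # is how the gateway handlers below decide whether to block a request.
    result = await check_guardrails(
        messages, POLICY, invariant_authorization="Bearer <INVARIANT_API_KEY>"
    )
    print(result.get("errors", []))

asyncio.run(main())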
@@ -4,8 +4,6 @@ import asyncio
 import os
 import threading
 
-from integrations.guardails import _preload
-
 from httpx import HTTPStatusError
 
 
@@ -20,6 +18,8 @@ class GatewayConfig:
     Loads the guardrails from the file specified in GUARDRAILS_FILE_PATH.
     Returns the guardrails file content as a string.
     """
+    from integrations.guardrails import _preload
+
     guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "")
     if not guardrails_file:
         print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]")
@@ -1,132 +0,0 @@
(entire file removed; its contents move, updated, into gateway/integrations/guardrails.py below)

"""Utility functions for Guardrails execution."""

import asyncio
import os
import time
from typing import Any, Dict, List
from functools import wraps

import httpx

DEFAULT_API_URL = "https://guardrail.invariantnet.com"


# Timestamps of last API calls per guardrails string
_guardrails_cache = {}
# Locks per guardrails string
_guardrails_locks = {}


def rate_limit(expiration_time: int = 3600):
    """
    Decorator to limit API calls to once per expiration_time seconds
    per unique guardrails string.

    Args:
        expiration_time (int): Time in seconds to cache the guardrails.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(guardrails: str, *args, **kwargs):
            now = time.time()

            # Get or create a per-guardrail lock
            if guardrails not in _guardrails_locks:
                _guardrails_locks[guardrails] = asyncio.Lock()
            guardrail_lock = _guardrails_locks[guardrails]

            async with guardrail_lock:
                last_called = _guardrails_cache.get(guardrails)

                if last_called and (now - last_called < expiration_time):
                    # Skipping API call: Guardrails '{guardrails}' already
                    # preloaded within expiration_time
                    return

                # Update cache timestamp
                _guardrails_cache[guardrails] = now

                try:
                    await func(guardrails, *args, **kwargs)
                finally:
                    _guardrails_locks.pop(guardrails, None)

        return wrapper

    return decorator


@rate_limit(3600)  # Don't preload the same guardrails string more than once per hour
async def _preload(guardrails: str, invariant_authorization: str) -> None:
    """
    Calls the Guardrails API to preload the provided policy for faster checking later.

    Args:
        guardrails (str): The guardrails to preload.
        invariant_authorization (str): Value of the
            invariant-authorization header.
    """
    async with httpx.AsyncClient() as client:
        url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
        result = await client.post(
            f"{url}/api/v1/policy/load",
            json={"policy": guardrails},
            headers={
                "Authorization": invariant_authorization,
                "Accept": "application/json",
            },
        )
        result.raise_for_status()


async def preload_guardrails(context: "RequestContextData") -> None:
    """
    Preloads the guardrails for faster checking later.

    Args:
        context: RequestContextData object.
    """
    if not context.config or not context.config.guardrails:
        return

    try:
        task = asyncio.create_task(
            _preload(context.config.guardrails, context.invariant_authorization)
        )
        asyncio.shield(task)
    except Exception as e:
        print(f"Error scheduling preload_guardrails task: {e}")


async def check_guardrails(
    messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
) -> Dict[str, Any]:
    """
    Checks guardrails on the list of messages.

    Args:
        messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
        guardrails (str): The guardrails to check against.
        invariant_authorization (str): Value of the
            invariant-authorization header.

    Returns:
        Dict: Response containing guardrail check results.
    """
    async with httpx.AsyncClient() as client:
        url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
        try:
            result = await client.post(
                f"{url}/api/v1/policy/check",
                json={"messages": messages, "policy": guardrails},
                headers={
                    "Authorization": invariant_authorization,
                    "Accept": "application/json",
                },
            )
            print(f"Guardrail check response: {result.json()}")
            return result.json()
        except Exception as e:
            print(f"Failed to verify guardrails: {e}")
            return {"error": str(e)}
gateway/integrations/guardrails.py (new file, 358 lines)
@@ -0,0 +1,358 @@
"""Utility functions for Guardrails execution."""

import asyncio
import os
import time
from typing import Any, Dict, List
from functools import wraps

import httpx
from common.request_context_data import RequestContextData

DEFAULT_API_URL = "https://explorer.invariantlabs.ai"


# Timestamps of last API calls per guardrails string
_guardrails_cache = {}
# Locks per guardrails string
_guardrails_locks = {}

def rate_limit(expiration_time: int = 3600):
    """
    Decorator to limit API calls to once per expiration_time seconds
    per unique guardrails string.

    Args:
        expiration_time (int): Time in seconds to cache the guardrails.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(guardrails: str, *args, **kwargs):
            now = time.time()

            # Get or create a per-guardrail lock
            if guardrails not in _guardrails_locks:
                _guardrails_locks[guardrails] = asyncio.Lock()
            guardrail_lock = _guardrails_locks[guardrails]

            async with guardrail_lock:
                last_called = _guardrails_cache.get(guardrails)

                if last_called and (now - last_called < expiration_time):
                    # Skipping API call: Guardrails '{guardrails}' already
                    # preloaded within expiration_time
                    return

                # Update cache timestamp
                _guardrails_cache[guardrails] = now

                try:
                    await func(guardrails, *args, **kwargs)
                finally:
                    _guardrails_locks.pop(guardrails, None)

        return wrapper

    return decorator

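As an aside (not part of the file): a tiny demonstration of the decorator's contract, assuming the rate_limit defined above. A repeated call with the same guardrails string inside the expiration window is skipped; a different string runs.

import asyncio

calls = []

@rate_limit(expiration_time=3600)
async def fake_preload(guardrails: str) -> None:
    calls.append(guardrails)

async def demo():
    await fake_preload("raise 'x' if: ...")
    await fake_preload("raise 'x' if: ...")  # skipped: cached for an hour
    await fake_preload("another policy")     # different key: runs

asyncio.run(demo())
assert calls == ["raise 'x' if: ...", "another policy"]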
@rate_limit(3600)  # Don't preload the same guardrails string more than once per hour
async def _preload(guardrails: str, invariant_authorization: str) -> None:
    """
    Calls the Guardrails API to preload the provided policy for faster checking later.

    Args:
        guardrails (str): The guardrails to preload.
        invariant_authorization (str): Value of the
            invariant-authorization header.
    """
    async with httpx.AsyncClient() as client:
        url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
        result = await client.post(
            f"{url}/api/v1/policy/load",
            json={"policy": guardrails},
            headers={
                "Authorization": invariant_authorization,
                "Accept": "application/json",
            },
        )
        result.raise_for_status()

async def preload_guardrails(context: "RequestContextData") -> None:
    """
    Preloads the guardrails for faster checking later.

    Args:
        context: RequestContextData object.
    """
    if not context.config or not context.config.guardrails:
        return

    try:
        task = asyncio.create_task(
            _preload(context.config.guardrails, context.invariant_authorization)
        )
        asyncio.shield(task)
    except Exception as e:
        print(f"Error scheduling preload_guardrails task: {e}")

class ExtraItem:
    """
    Return an instance of this class from an instrumented stream callback to
    yield an extra item in the resulting stream.
    """

    def __init__(self, value, end_of_stream=False):
        self.value = value
        self.end_of_stream = end_of_stream

    def __str__(self):
        return f"<ExtraItem value={self.value} end_of_stream={self.end_of_stream}>"


class Replacement(ExtraItem):
    """
    Like ExtraItem, but used to replace the full request result in the case of
    an 'InstrumentedResponse'.
    """

    def __init__(self, value):
        super().__init__(value, end_of_stream=True)

    def __str__(self):
        return f"<Replacement value={self.value}>"

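As an aside (not part of the file): the hook contract in miniature. Returning an ExtraItem from on_start/on_chunk/on_end injects one extra chunk into the stream; end_of_stream=True additionally stops it, and Replacement is the end_of_stream variant used by InstrumentedResponse below to swap out the final result. The byte payload here is illustrative.

async def on_chunk(chunk):
    if b"forbidden" in chunk:
        # inject an error chunk and cut the stream short
        return ExtraItem(b'data: {"error": "blocked"}\n\n', end_of_stream=True)
    return None  # pass the chunk through untouched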
class InstrumentedStreamingResponse:
    def __init__(self):
        # request statistics
        self.stat_token_times = []
        self.stat_before_time = None
        self.stat_after_time = None

        self.stat_first_item_time = None

    async def on_chunk(self, chunk: Any) -> ExtraItem | None:
        """
        Hook invoked on every chunk (async).
        """
        pass

    async def on_start(self) -> ExtraItem | None:
        """
        Hook invoked once before the first chunk is streamed.
        """
        pass

    async def on_end(self) -> ExtraItem | None:
        """
        Hook invoked once after the last chunk has been streamed.
        """
        pass

    async def event_generator(self):
        """
        The underlying stream to instrument.

        Yields:
            The streamed data.
        """
        raise NotImplementedError("This method should be implemented in a subclass.")

    async def instrumented_event_generator(self):
        """
        Streams 'event_generator' and invokes all instrumented hooks.

        Yields:
            The streamed data.
        """
        try:
            start = time.time()

            # schedule on_start, which can run concurrently
            start_task = asyncio.create_task(self.on_start(), name="instrumentor:start")

            # create an async iterator from the event generator
            aiterable = aiter(self.event_generator())

            # [STAT] capture start time of first item
            start_first_item_request = time.time()

            # waits for the first item of the iterable
            async def wait_for_first_item():
                nonlocal start_first_item_request, aiterable

                r = await aiterable.__anext__()
                if self.stat_first_item_time is None:
                    # [STAT] capture time to first item
                    self.stat_first_item_time = time.time() - start_first_item_request
                return r

            next_item_task = asyncio.create_task(
                wait_for_first_item(), name="instrumentor:next:first"
            )

            # check if 'start_task' yields an extra item
            if extra_item := await start_task:
                # yield extra value before any real items
                yield extra_item.value
                # stop the stream if end_of_stream is True
                if extra_item.end_of_stream:
                    # if the first item is not already available
                    if not next_item_task.done():
                        # cancel the task
                        next_item_task.cancel()
                    # [STAT] capture time to first item to be now +0.01
                    if self.stat_first_item_time is None:
                        self.stat_first_item_time = (
                            time.time() - start_first_item_request
                        ) + 0.01
                    # don't wait for the first item if end_of_stream is True
                    return

            # [STAT] capture before time stamp
            self.stat_before_time = time.time() - start

            while True:
                # wait for the next item
                try:
                    item = await next_item_task
                except StopAsyncIteration:
                    break

                # schedule the next item
                next_item_task = asyncio.create_task(
                    aiterable.__anext__(), name="instrumentor:next"
                )

                # [STAT] capture token time stamp
                if len(self.stat_token_times) == 0:
                    self.stat_token_times.append(time.time() - start)
                else:
                    self.stat_token_times.append(
                        time.time() - start - sum(self.stat_token_times)
                    )

                if extra_item := await self.on_chunk(item):
                    yield extra_item.value
                    # if end_of_stream is True, stop the stream
                    if extra_item.end_of_stream:
                        # cancel next task
                        next_item_task.cancel()
                        return

                # yield item
                yield item

            # run on_end before closing the stream (may yield an extra value)
            if extra_item := await self.on_end():
                # yield the extra value as the last item
                yield extra_item.value
                # we ignore end_of_stream here, because we are already at the end

            # [STAT] capture after time stamp
            self.stat_after_time = time.time() - start
        finally:
            # [STAT] end all open intervals if not already closed
            if self.stat_before_time is None:
                # set 'before' if the stream ended before it was captured
                self.stat_before_time = time.time() - start
            if self.stat_after_time is None:
                self.stat_after_time = 0
            if self.stat_first_item_time is None:
                self.stat_first_item_time = 0

            # print statistics
            token_times_5_decimale = str([f"{x:.5f}" for x in self.stat_token_times])
            print(
                f"[STATS]\n  [token times: {token_times_5_decimale} ({len(self.stat_token_times)})]"
            )
            print(f"  [before: {self.stat_before_time:.2f}s] ")
            print(f"  [time-to-first-item: {self.stat_first_item_time:.2f}s]")
            print(
                f"  [zero-latency: {' TRUE' if self.stat_before_time < self.stat_first_item_time else 'FALSE'}]"
            )
            print(
                f"  [extra-latency: {self.stat_before_time - self.stat_first_item_time:.2f}s]"
            )
            print(f"  [after: {self.stat_after_time:.2f}s]")
            if len(self.stat_token_times) > 0:
                print(
                    f"  [average token time: {sum(self.stat_token_times) / len(self.stat_token_times):.2f}s]"
                )
            print(f"  [total: {time.time() - start:.2f}s]")

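As an aside (not part of the file): a minimal subclass wiring the hooks together, assuming the class above. It makes the pipelining visible — on_start runs concurrently with the fetch of the first item, so a passing input check adds no extra latency.

import asyncio

class CountingStream(InstrumentedStreamingResponse):
    def __init__(self):
        super().__init__()
        self.seen = 0

    async def event_generator(self):
        # stand-in for an upstream HTTP byte stream
        for part in (b"a", b"b", b"c"):
            yield part

    async def on_start(self):
        # input guardrailing would happen here; None means "let the stream run"
        return None

    async def on_chunk(self, chunk):
        self.seen += 1
        return None

async def demo_stream():
    s = CountingStream()
    return [c async for c in s.instrumented_event_generator()]

assert asyncio.run(demo_stream()) == [b"a", b"b", b"c"]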
class InstrumentedResponse(InstrumentedStreamingResponse):
    """
    A class to instrument an async request with hooks for concurrent
    pre-processing and post-processing (input and output guardrailing).
    """

    async def event_generator(self):
        """
        We implement 'event_generator' as a single-item stream,
        where the item is the full result of the request.
        """
        yield await self.request()

    async def request(self):
        """
        This method should be implemented in a subclass to perform the actual request.
        """
        raise NotImplementedError("This method should be implemented in a subclass.")

    async def instrumented_request(self):
        """
        Returns the 'Response' object of the request, after applying all instrumented hooks.
        """
        results = [r async for r in self.instrumented_event_generator()]
        assert len(results) >= 1, "InstrumentedResponse must yield at least one item"

        # we return the last item: in case the end callback yields an extra item, we
        # don't return the actual result but the 'end' result, e.g. for output guardrailing.
        return results[-1]

async def check_guardrails(
    messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
) -> Dict[str, Any]:
    """
    Checks guardrails on the list of messages.

    Args:
        messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
        guardrails (str): The guardrails to check against.
        invariant_authorization (str): Value of the
            invariant-authorization header.

    Returns:
        Dict: Response containing guardrail check results.
    """
    async with httpx.AsyncClient() as client:
        url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
        try:
            result = await client.post(
                f"{url}/api/v1/policy/check",
                json={"messages": messages, "policy": guardrails},
                headers={
                    "Authorization": invariant_authorization,
                    "Accept": "application/json",
                },
            )
            if not result.is_success:
                raise Exception(
                    f"Guardrails check failed: {result.status_code} - {result.text}"
                )
            print(f"Guardrail check response: {result.json()}")
            return result.json()
        except Exception as e:
            print(f"Failed to verify guardrails: {e}")
            return {"error": str(e)}
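As an aside (not part of the file): unlike the deleted module above, this version folds non-2xx responses into the same failure shape as transport errors, so a caller only needs to branch on the "error" key. A hedged sketch of that calling convention:

async def check_and_log(messages, policy, auth):
    result = await check_guardrails(messages, policy, auth)
    if "error" in result:
        # API unreachable, bad auth, or a non-2xx status
        print("guardrails check unavailable:", result["error"])
        return []
    return result.get("errors", [])  # actual guardrail violations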
@@ -5,6 +5,7 @@ import json
 from typing import Any, Optional
 
 import httpx
+from regex import R
 from common.config_manager import GatewayConfig, GatewayConfigManager
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from starlette.responses import StreamingResponse
@@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import (
 )
 from common.authorization import extract_authorization_from_headers
 from common.request_context_data import RequestContextData
-from integrations.guardails import check_guardrails, preload_guardrails
+from integrations.guardrails import (
+    ExtraItem,
+    InstrumentedResponse,
+    InstrumentedStreamingResponse,
+    Replacement,
+    check_guardrails,
+    preload_guardrails,
+)
 
 gateway = APIRouter()
@@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway(
 
     if request_json.get("stream"):
         return await handle_streaming_response(context, client, anthropic_request)
-    response = await client.send(anthropic_request)
-    return await handle_non_streaming_response(context, response)
+    return await handle_non_streaming_response(context, client, anthropic_request)
 
 
 def create_metadata(
@@ -110,7 +117,8 @@ def combine_request_and_response_messages(
             {"role": "system", "content": context.request_json.get("system")}
         )
     messages.extend(context.request_json.get("messages", []))
-    messages.append(json_response)
+    if len(json_response) > 0:
+        messages.append(json_response)
     return messages
@@ -154,56 +162,282 @@ async def push_to_explorer(
     )
 
 
+class InstrumentedAnthropicResponse(InstrumentedResponse):
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        anthropic_request: httpx.Request,
+    ):
+        super().__init__()
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.anthropic_request: httpx.Request = anthropic_request
+
+        # response data
+        self.response: Optional[httpx.Response] = None
+        self.response_string: Optional[str] = None
+        self.json_response: Optional[dict[str, Any]] = None
+
+        # guardrailing response (if any)
+        self.guardrails_execution_result = {}
+
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, {}
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The request did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            {},
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we prevent the request from going through
+                # and return an error instead
+                return Replacement(
+                    Response(
+                        content=error_chunk,
+                        status_code=400,
+                        media_type="application/json",
+                        headers={"content-type": "application/json"},
+                    )
+                )
+    async def request(self):
+        self.response = await self.client.send(self.anthropic_request)
+
+        try:
+            json_response = self.response.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error {e}",
+            ) from e
+        if self.response.status_code != 200:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=json_response.get("error", "Unknown error from Anthropic"),
+            )
+
+        self.json_response = json_response
+        self.response_string = json.dumps(json_response)
+
+        return self._make_response(
+            content=self.response_string,
+            status_code=self.response.status_code,
+        )
+
+    def _make_response(self, content: str, status_code: int):
+        """Creates a new Response object with the correct headers and content."""
+        assert self.response is not None, "response is None"
+
+        updated_headers = self.response.headers.copy()
+        updated_headers.pop("Content-Length", None)
+
+        return Response(
+            content=content,
+            status_code=status_code,
+            media_type="application/json",
+            headers=dict(updated_headers),
+        )
+    async def on_end(self):
+        """Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
+        # ensure the response data is available
+        assert self.response is not None, "response is None"
+        assert self.json_response is not None, "json_response is None"
+        assert self.response_string is not None, "response_string is None"
+
+        guardrails_execution_result = {}  # default when no guardrails are configured
+        if self.context.config and self.context.config.guardrails:
+            # Block on the guardrails check
+            guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.json_response
+            )
+            if guardrails_execution_result.get("errors", []):
+                guardrail_response_string = json.dumps(
+                    {
+                        "error": "[Invariant] The response did not pass the guardrails",
+                        "details": guardrails_execution_result,
+                    }
+                )
+
+                # push to explorer (if configured)
+                if self.context.dataset_name:
+                    # Push to Explorer - don't block on its response
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.json_response,
+                            guardrails_execution_result,
+                        )
+                    )
+
+                return Replacement(
+                    self._make_response(
+                        content=guardrail_response_string,
+                        status_code=400,
+                    )
+                )
+
+        # push to explorer (if configured)
+        if self.context.dataset_name:
+            # Push to Explorer - don't block on its response
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.json_response, guardrails_execution_result
+                )
+            )
 async def handle_non_streaming_response(
     context: RequestContextData,
-    response: httpx.Response,
+    client: httpx.AsyncClient,
+    anthropic_request: httpx.Request,
 ) -> Response:
     """Handles non-streaming Anthropic responses"""
-    try:
-        json_response = response.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
-        ) from e
-    if response.status_code != 200:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=json_response.get("error", "Unknown error from Anthropic"),
-        )
-
-    guardrails_execution_result = {}
-    response_string = json.dumps(json_response)
-    response_code = response.status_code
-
-    if context.config and context.config.guardrails:
-        # Block on the guardrails check
-        guardrails_execution_result = await get_guardrails_check_result(
-            context, json_response
-        )
-        if guardrails_execution_result.get("errors", []):
-            response_string = json.dumps(
-                {
-                    "error": "[Invariant] The response did not pass the guardrails",
-                    "details": guardrails_execution_result,
-                }
-            )
-            response_code = 400
-    if context.dataset_name:
-        # Push to Explorer - don't block on its response
-        asyncio.create_task(
-            push_to_explorer(context, json_response, guardrails_execution_result)
-        )
-
-    updated_headers = response.headers.copy()
-    updated_headers.pop("Content-Length", None)
-    return Response(
-        content=response_string,
-        status_code=response_code,
-        media_type="application/json",
-        headers=dict(updated_headers),
-    )
+    response = InstrumentedAnthropicResponse(
+        context=context,
+        client=client,
+        anthropic_request=anthropic_request,
+    )
+    return await response.instrumented_request()
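As an aside (not part of the commit): handling the 400 the gateway now returns when a request or response is blocked. Note the two body shapes built above — on_start nests message/details under "error", while on_end uses a string "error" with a sibling "details". The inner item under "errors" is a hypothetical illustration, not a documented schema.

import json

blocked_response_text = json.dumps(
    {
        "error": {
            "message": "[Invariant] The request did not pass the guardrails",
            "details": {"errors": [{"guardrail": "Don't say 'Hello'"}]},  # assumed item shape
        }
    }
)

body = json.loads(blocked_response_text)
err = body.get("error")
if isinstance(err, dict):          # request blocked in on_start
    message, details = err["message"], err["details"]
else:                              # response blocked in on_end ("error" is a string)
    message, details = err, body.get("details", {})
print(message, details.get("errors", []))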
+class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        anthropic_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request parameters
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.anthropic_request: httpx.Request = anthropic_request
+
+        # response data
+        self.merged_response = {}
+
+        # guardrailing response (if any)
+        self.guardrails_execution_result = {}
+
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The request did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.merged_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we end the stream prematurely (end_of_stream=True)
+                # and yield an error chunk instead of actually beginning the stream
+                return ExtraItem(
+                    f"event: error\ndata: {error_chunk}\n\n".encode(),
+                    end_of_stream=True,
+                )
+
+    async def event_generator(self):
+        """Actual streaming response generator."""
+        response = await self.client.send(self.anthropic_request, stream=True)
+        if response.status_code != 200:
+            error_content = await response.aread()
+            try:
+                error_json = json.loads(error_content)
+                error_detail = error_json.get("error", "Unknown error from Anthropic")
+            except json.JSONDecodeError:
+                error_detail = {
+                    "error": "Failed to decode error response from Anthropic"
+                }
+            raise HTTPException(status_code=response.status_code, detail=error_detail)
+
+        # iterate over the response stream
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+    async def on_chunk(self, chunk):
+        decoded_chunk = chunk.decode().strip()
+        if not decoded_chunk:
+            return
+
+        # process chunk and extend the merged_response
+        process_chunk(decoded_chunk, self.merged_response)
+
+        # on the last stream chunk, run output guardrails
+        if (
+            "event: message_stop" in decoded_chunk
+            and self.context.config
+            and self.context.config.guardrails
+        ):
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "type": "error",
+                        "error": {
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        },
+                    }
+                )
+
+                # yield an extra error chunk (without preventing the original chunk
+                # from going through after, so the client still gets a proper
+                # message_stop event)
+                return ExtraItem(
+                    value=f"event: error\ndata: {error_chunk}\n\n".encode()
+                )
+
+    async def on_end(self):
+        """on_end: send the full merged response to the explorer (if configured)."""
+        # don't block on the response from explorer (.create_task)
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context,
+                    self.merged_response,
+                    self.guardrails_execution_result,
+                )
+            )
+
+
 async def handle_streaming_response(
     context: RequestContextData,
@@ -211,63 +445,15 @@ async def handle_streaming_response(
     anthropic_request: httpx.Request,
 ) -> StreamingResponse:
     """Handles streaming Anthropic responses"""
-    merged_response = {}
+    response = InstrumentedAnthropicStreamingResposne(
+        context=context,
+        client=client,
+        anthropic_request=anthropic_request,
+    )
 
-    response = await client.send(anthropic_request, stream=True)
-    if response.status_code != 200:
-        error_content = await response.aread()
-        try:
-            error_json = json.loads(error_content)
-            error_detail = error_json.get("error", "Unknown error from Anthropic")
-        except json.JSONDecodeError:
-            error_detail = {"error": "Failed to decode error response from Anthropic"}
-        raise HTTPException(status_code=response.status_code, detail=error_detail)
-
-    async def event_generator() -> Any:
-        async for chunk in response.aiter_bytes():
-            decoded_chunk = chunk.decode().strip()
-            if not decoded_chunk:
-                continue
-            process_chunk(decoded_chunk, merged_response)
-            if (
-                "event: message_stop" in decoded_chunk
-                and context.config
-                and context.config.guardrails
-            ):
-                # Block on the guardrails check
-                guardrails_execution_result = await get_guardrails_check_result(
-                    context, merged_response
-                )
-                if guardrails_execution_result.get("errors", []):
-                    error_chunk = json.dumps(
-                        {
-                            "type": "error",
-                            "error": {
-                                "message": "[Invariant] The response did not pass the guardrails",
-                                "details": guardrails_execution_result,
-                            },
-                        }
-                    )
-                    # Push annotated trace to the explorer - don't block on its response
-                    if context.dataset_name:
-                        asyncio.create_task(
-                            push_to_explorer(
-                                context,
-                                merged_response,
-                                guardrails_execution_result,
-                            )
-                        )
-                    yield f"event: error\ndata: {error_chunk}\n\n".encode()
-                    return
-            yield chunk
-
-        if context.dataset_name:
-            # Push to Explorer - don't block on the response
-            asyncio.create_task(push_to_explorer(context, merged_response))
-
-    generator = event_generator()
-
-    return StreamingResponse(generator, media_type="text/event-stream")
+    return StreamingResponse(
+        response.instrumented_event_generator(), media_type="text/event-stream"
+    )
 
 
 def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
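As an aside (not part of the commit): what an Anthropic streaming client receives when output guardrailing fails — a synthetic "event: error" SSE frame injected before the stream's regular message_stop event, matching the ExtraItem built in on_chunk above. The frame below is an illustrative literal.

example_frame = (
    b'event: error\n'
    b'data: {"type": "error", "error": {"message": '
    b'"[Invariant] The response did not pass the guardrails", "details": {"errors": []}}}\n\n'
)
event_line, data_line = example_frame.decode().strip().split("\n", 1)
assert event_line == "event: error"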
@@ -2,7 +2,7 @@
 
 import asyncio
 import json
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import httpx
 from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -15,8 +15,16 @@ from common.constants import (
 from common.authorization import extract_authorization_from_headers
 from common.request_context_data import RequestContextData
 from converters.gemini_to_invariant import convert_request, convert_response
+from integrations.guardrails import (
+    ExtraItem,
+    InstrumentedResponse,
+    InstrumentedStreamingResponse,
+    Replacement,
+    preload_guardrails,
+    check_guardrails,
+)
 from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
-from integrations.guardails import check_guardrails, preload_guardrails
+from integrations.guardrails import check_guardrails, preload_guardrails
 
 gateway = APIRouter()
@@ -82,13 +90,169 @@ async def gemini_generate_content_gateway(
             client,
             gemini_request,
         )
-    response = await client.send(gemini_request)
     return await handle_non_streaming_response(
         context,
-        response,
+        client,
+        gemini_request,
     )
 
+class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        gemini_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request data
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.gemini_request: httpx.Request = gemini_request
+
+        # Store the progressively merged response
+        self.merged_response = {
+            "candidates": [{"content": {"parts": []}, "finishReason": None}]
+        }
+
+        # guardrailing execution result (if any)
+        self.guardrails_execution_result: Optional[dict[str, Any]] = None
+
+    def make_refusal(
+        self,
+        location: Literal["request", "response"],
+        guardrails_execution_result: dict[str, Any],
+    ) -> dict:
+        return {
+            "candidates": [
+                {
+                    "content": {
+                        "parts": [
+                            {
+                                "text": f"[Invariant] The {location} did not pass the guardrails",
+                            }
+                        ],
+                    }
+                }
+            ],
+            "error": {
+                "code": 400,
+                "message": f"[Invariant] The {location} did not pass the guardrails",
+                "details": guardrails_execution_result,
+                "status": "INVARIANT_GUARDRAILS_VIOLATION",
+            },
+            "promptFeedback": {
+                "blockReason": "SAFETY",
+                "block_reason_message": f"[Invariant] The {location} did not pass the guardrails: "
+                + json.dumps(guardrails_execution_result),
+                "safetyRatings": [
+                    {
+                        "category": "HARM_CATEGORY_UNSPECIFIED",
+                        "probability": "HIGH",
+                        "blocked": True,
+                    }
+                ],
+            },
+        }
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, {}
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    self.make_refusal("request", self.guardrails_execution_result)
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            {},
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we end the stream prematurely (end_of_stream=True)
+                # and yield an error chunk instead of actually beginning the stream
+                return ExtraItem(
+                    f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True
+                )
+    async def event_generator(self):
+        response = await self.client.send(self.gemini_request, stream=True)
+
+        if response.status_code != 200:
+            error_content = await response.aread()
+            try:
+                error_json = json.loads(error_content.decode("utf-8"))
+                error_detail = error_json.get("error", "Unknown error from Gemini API")
+            except json.JSONDecodeError:
+                error_detail = {"error": "Failed to parse Gemini error response"}
+            raise HTTPException(status_code=response.status_code, detail=error_detail)
+
+        async for chunk in response.aiter_bytes():
+            yield chunk
+    async def on_chunk(self, chunk):
+        chunk_text = chunk.decode().strip()
+        if not chunk_text:
+            return
+
+        # Parse and update merged_response incrementally
+        process_chunk_text(self.merged_response, chunk_text)
+
+        # runs on the last stream item
+        if (
+            self.merged_response.get("candidates", [])
+            and self.merged_response.get("candidates")[0].get("finishReason", "")
+            and self.context.config
+            and self.context.config.guardrails
+        ):
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    self.make_refusal("response", self.guardrails_execution_result)
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.merged_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                return ExtraItem(
+                    value=f"data: {error_chunk}\r\n\r\n".encode(),
+                    # for Gemini we have to end the stream prematurely, as the client SDK
+                    # will not stop streaming when it encounters an error
+                    end_of_stream=True,
+                )
+
+    async def on_end(self):
+        """Runs when the stream ends."""
+        # Push annotated trace to the explorer - don't block on its response
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context,
+                    self.merged_response,
+                    self.guardrails_execution_result,
+                )
+            )
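As an aside (not part of the commit): what a Gemini client sees when the gateway injects a refusal and ends the stream early. The payload shape comes from make_refusal above, and the "data: ...\r\n\r\n" framing matches the ExtraItem yielded in on_start/on_chunk; the literal below is a trimmed illustration.

import json

raw_chunk = (
    b'data: {"error": {"code": 400, '
    b'"message": "[Invariant] The response did not pass the guardrails", '
    b'"status": "INVARIANT_GUARDRAILS_VIOLATION"}}\r\n\r\n'
)
payload = json.loads(raw_chunk.decode().removeprefix("data: ").strip())
if payload.get("error", {}).get("status") == "INVARIANT_GUARDRAILS_VIOLATION":
    # the stream ends after this chunk; surface the block reason to the user
    print(payload["error"]["message"])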
 async def stream_response(
     context: RequestContextData,
     client: httpx.AsyncClient,
@@ -96,76 +260,21 @@ async def stream_response(
 ) -> Response:
     """Handles streaming the Gemini response to the client"""
 
-    response = await client.send(gemini_request, stream=True)
-    if response.status_code != 200:
-        error_content = await response.aread()
-        try:
-            error_json = json.loads(error_content.decode("utf-8"))
-            error_detail = error_json.get("error", "Unknown error from Gemini API")
-        except json.JSONDecodeError:
-            error_detail = {"error": "Failed to parse Gemini error response"}
-        raise HTTPException(status_code=response.status_code, detail=error_detail)
+    response = InstrumentedStreamingGeminiResponse(
+        context=context,
+        client=client,
+        gemini_request=gemini_request,
+    )
 
-    async def event_generator() -> Any:
-        # Store the progressively merged response
-        merged_response = {
-            "candidates": [{"content": {"parts": []}, "finishReason": None}]
-        }
-
-        async for chunk in response.aiter_bytes():
-            chunk_text = chunk.decode().strip()
-            if not chunk_text:
-                continue
-
-            # Parse and update merged_response incrementally
-            process_chunk_text(merged_response, chunk_text)
-
-            if (
-                merged_response.get("candidates", [])
-                and merged_response.get("candidates")[0].get("finishReason", "")
-                and context.config
-                and context.config.guardrails
-            ):
-                # Block on the guardrails check
-                guardrails_execution_result = await get_guardrails_check_result(
-                    context, merged_response
-                )
-                if guardrails_execution_result.get("errors", []):
-                    error_chunk = json.dumps(
-                        {
-                            "error": {
-                                "code": 400,
-                                "message": "[Invariant] The response did not pass the guardrails",
-                                "details": guardrails_execution_result,
-                                "status": "INVARIANT_GUARDRAILS_VIOLATION",
-                            },
-                        }
-                    )
-                    # Push annotated trace to the explorer - don't block on its response
-                    if context.dataset_name:
-                        asyncio.create_task(
-                            push_to_explorer(
-                                context,
-                                merged_response,
-                                guardrails_execution_result,
-                            )
-                        )
-                    yield f"data: {error_chunk}\n\n".encode()
-                    return
-
-            # Yield chunk immediately to the client
-            yield chunk
-
-        if context.dataset_name:
-            # Push to Explorer - don't block on the response
-            asyncio.create_task(
-                push_to_explorer(
-                    context,
-                    merged_response,
-                )
-            )
-
-    return StreamingResponse(event_generator(), media_type="text/event-stream")
+    async def event_generator():
+        async for chunk in response.instrumented_event_generator():
+            yield chunk
+            print("chunk", chunk)
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+    )
 
 
 def process_chunk_text(
@@ -281,53 +390,165 @@ async def push_to_explorer(
     )
 
 
+class InstrumentedGeminiResponse(InstrumentedResponse):
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        gemini_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request data
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.gemini_request: httpx.Request = gemini_request
+
+        # response data
+        self.response: Optional[httpx.Response] = None
+        self.response_json: Optional[dict[str, Any]] = None
+
+        # guardrails execution result (if any)
+        self.guardrails_execution_result: Optional[dict[str, Any]] = None
+
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, {}
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "code": 400,
+                            "message": "[Invariant] The request did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
+                        },
+                        "prompt_feedback": {
+                            "blockReason": "SAFETY",
+                            "safetyRatings": [
+                                {
+                                    "category": "HARM_CATEGORY_UNSPECIFIED",
+                                    "probability": 0.0,
+                                    "blocked": True,
+                                }
+                            ],
+                        },
+                    }
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            {},
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we end the stream prematurely (end_of_stream=True)
+                # and yield an error chunk instead of actually beginning the stream
+                return Replacement(
+                    Response(
+                        content=error_chunk,
+                        status_code=400,
+                        media_type="application/json",
+                        headers={
+                            "Content-Type": "application/json",
+                        },
+                    )
+                )
+    async def request(self):
+        self.response = await self.client.send(self.gemini_request)
+
+        response_string = self.response.text
+        response_code = self.response.status_code
+
+        try:
+            self.response_json = self.response.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail="Invalid JSON response received from Gemini API",
+            ) from e
+        if self.response.status_code != 200:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=self.response_json.get("error", "Unknown error from Gemini API"),
+            )
+
+        return Response(
+            content=response_string,
+            status_code=response_code,
+            media_type="application/json",
+            headers=dict(self.response.headers),
+        )
+    async def on_end(self):
+        response_string = json.dumps(self.response_json)
+        response_code = self.response.status_code
+
+        guardrails_execution_result = {}  # default when no guardrails are configured
+        if self.context.config and self.context.config.guardrails:
+            # Block on the guardrails check
+            guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.response_json
+            )
+            if guardrails_execution_result.get("errors", []):
+                response_string = json.dumps(
+                    {
+                        "error": {
+                            "code": 400,
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": guardrails_execution_result,
+                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
+                        },
+                    }
+                )
+                response_code = 400
+
+                if self.context.dataset_name:
+                    # Push to Explorer - don't block on its response
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.response_json,
+                            guardrails_execution_result,
+                        )
+                    )
+
+                return Replacement(
+                    Response(
+                        content=response_string,
+                        status_code=response_code,
+                        media_type="application/json",
+                        headers=dict(self.response.headers),
+                    )
+                )
+
+        # Otherwise, also push to Explorer - don't block on its response
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.response_json, guardrails_execution_result
+                )
+            )
 async def handle_non_streaming_response(
     context: RequestContextData,
-    response: httpx.Response,
+    client: httpx.AsyncClient,
+    gemini_request: httpx.Request,
 ) -> Response:
     """Handles non-streaming Gemini responses"""
-    try:
-        response_json = response.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail="Invalid JSON response received from Gemini API",
-        ) from e
-    if response.status_code != 200:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=response_json.get("error", "Unknown error from Gemini API"),
-        )
-    guardrails_execution_result = {}
-    response_string = json.dumps(response_json)
-    response_code = response.status_code
-
-    if context.config and context.config.guardrails:
-        # Block on the guardrails check
-        guardrails_execution_result = await get_guardrails_check_result(
-            context, response_json
-        )
-        if guardrails_execution_result.get("errors", []):
-            response_string = json.dumps(
-                {
-                    "error": {
-                        "code": 400,
-                        "message": "[Invariant] The response did not pass the guardrails",
-                        "details": guardrails_execution_result,
-                        "status": "INVARIANT_GUARDRAILS_VIOLATION",
-                    },
-                }
-            )
-            response_code = 400
-    if context.dataset_name:
-        # Push to Explorer - don't block on its response
-        asyncio.create_task(
-            push_to_explorer(context, response_json, guardrails_execution_result)
-        )
-
-    return Response(
-        content=response_string,
-        status_code=response_code,
-        media_type="application/json",
-        headers=dict(response.headers),
-    )
+    response = InstrumentedGeminiResponse(
+        context=context,
+        client=client,
+        gemini_request=gemini_request,
+    )
+    return await response.instrumented_request()
@@ -13,7 +13,13 @@ from common.constants import (
     IGNORED_HEADERS,
 )
 from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
-from integrations.guardails import check_guardrails, preload_guardrails
+from integrations.guardrails import (
+    ExtraItem,
+    InstrumentedResponse,
+    InstrumentedStreamingResponse,
+    check_guardrails,
+    preload_guardrails,
+)
 from common.authorization import extract_authorization_from_headers
 from common.request_context_data import RequestContextData
@@ -74,19 +80,159 @@ async def openai_chat_completions_gateway(
     asyncio.create_task(preload_guardrails(context))
 
     if request_json.get("stream", False):
-        return await stream_response(
+        return await handle_stream_response(
            context,
            client,
            open_ai_request,
        )
-    response = await client.send(open_ai_request)
-    return await handle_non_streaming_response(
-        context,
-        response,
-    )
+    return await handle_non_stream_response(context, client, open_ai_request)
 
-async def stream_response(
+class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
+    """
+    Does a streaming OpenAI completion request at the core, but also checks
+    guardrails before (concurrently) and after the request.
+    """
+
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        open_ai_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request parameters
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.open_ai_request: httpx.Request = open_ai_request
+
+        # guardrailing output (if any)
+        self.guardrails_execution_result: Optional[dict] = None
+
+        # merged_response will be updated with the data from the chunks in the stream.
+        # At the end of the stream, this will be sent to the explorer.
+        self.merged_response = {
+            "id": None,
+            "object": "chat.completion",
+            "created": None,
+            "model": None,
+            "choices": [],
+            "usage": None,
+        }
+
+        # Each chunk in the stream contains a list called "choices"; each entry in
+        # the list has an index.
+        # A choice has a field called "delta" which may contain a list called "tool_calls".
+        # Maps the choice index in the stream to the index in the merged_response["choices"] list
+        self.choice_mapping_by_index = {}
+        # Combines the choice index and tool call index to uniquely identify a tool call
+        self.tool_call_mapping_by_index = {}
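As an aside (not part of the commit): a sketch of the index-based merging the comments above describe — each streamed chunk carries "choices" entries addressed by "index", whose "delta" fragments are appended into one merged choice. The field names mirror the OpenAI chunk format; the helper name is hypothetical.

merged = {"choices": []}
choice_by_index = {}

def merge_chunk(chunk: dict) -> None:
    for choice in chunk.get("choices", []):
        idx = choice["index"]
        if idx not in choice_by_index:
            choice_by_index[idx] = {"index": idx, "message": {"content": ""}}
            merged["choices"].append(choice_by_index[idx])
        delta = choice.get("delta", {})
        if delta.get("content"):
            choice_by_index[idx]["message"]["content"] += delta["content"]

merge_chunk({"choices": [{"index": 0, "delta": {"content": "Hel"}}]})
merge_chunk({"choices": [{"index": 0, "delta": {"content": "lo"}}]})
assert merged["choices"][0]["message"]["content"] == "Hello"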
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The request did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.merged_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we end the stream prematurely (end_of_stream=True)
+                # and yield an error chunk instead of actually beginning the stream
+                return ExtraItem(
+                    f"data: {error_chunk}\n\n".encode(),
+                    end_of_stream=True,
+                )
+    async def on_chunk(self, chunk):
+        # process and check each chunk
+        chunk_text = chunk.decode().strip()
+        if not chunk_text:
+            return
+
+        # Process the chunk
+        # This will update merged_response with the data from the chunk
+        process_chunk_text(
+            chunk_text,
+            self.merged_response,
+            self.choice_mapping_by_index,
+            self.tool_call_mapping_by_index,
+        )
+
+        # check guardrails at the end of the stream (on the '[DONE]' SSE chunk)
+        if (
+            "data: [DONE]" in chunk_text
+            and self.context.config
+            and self.context.config.guardrails
+        ):
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # yield an extra error chunk (without preventing the original chunk
+                # from going through after)
+                return ExtraItem(f"data: {error_chunk}\n\n".encode())
+
+        # push will happen in on_end
+
+    async def on_end(self):
+        """Sends the full merged response to the explorer."""
+        # don't block on the response from explorer (.create_task)
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.merged_response, self.guardrails_execution_result
+                )
+            )
+    async def event_generator(self):
+        """
+        Actual OpenAI stream response.
+        """
+        response = await self.client.send(self.open_ai_request, stream=True)
+        if response.status_code != 200:
+            error_content = await response.aread()
+            try:
+                error_json = json.loads(error_content.decode("utf-8"))
+                error_detail = error_json.get("error", "Unknown error from OpenAI API")
+            except json.JSONDecodeError:
+                error_detail = {"error": "Failed to parse OpenAI error response"}
+            raise HTTPException(status_code=response.status_code, detail=error_detail)
+
+        # stream out chunks
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+async def handle_stream_response(
     context: RequestContextData,
     client: httpx.AsyncClient,
     open_ai_request: httpx.Request,
@@ -98,89 +244,15 @@ async def stream_response(
     It is sent to the Invariant Explorer at the end of the stream
     """
-    response = await client.send(open_ai_request, stream=True)
-    if response.status_code != 200:
-        error_content = await response.aread()
-        try:
-            error_json = json.loads(error_content.decode("utf-8"))
-            error_detail = error_json.get("error", "Unknown error from OpenAI API")
-        except json.JSONDecodeError:
-            error_detail = {"error": "Failed to parse OpenAI error response"}
-        raise HTTPException(status_code=response.status_code, detail=error_detail)
+    response = InstrumentedOpenAIStreamResponse(
+        context,
+        client,
+        open_ai_request,
+    )
 
-    async def event_generator() -> Any:
-        # merged_response will be updated with the data from the chunks in the stream
-        # At the end of the stream, this will be sent to the explorer
-        merged_response = {
-            "id": None,
-            "object": "chat.completion",
-            "created": None,
-            "model": None,
-            "choices": [],
-            "usage": None,
-        }
-        # Each chunk in the stream contains a list called "choices"; each entry in
-        # the list has an index.
-        # A choice has a field called "delta" which may contain a list called "tool_calls".
-        # Maps the choice index in the stream to the index in the merged_response["choices"] list
-        choice_mapping_by_index = {}
-        # Combines the choice index and tool call index to uniquely identify a tool call
-        tool_call_mapping_by_index = {}
-
-        async for chunk in response.aiter_bytes():
-            chunk_text = chunk.decode().strip()
-            if not chunk_text:
-                continue
-
-            # Process the chunk
-            # This will update merged_response with the data from the chunk
-            process_chunk_text(
-                chunk_text,
-                merged_response,
-                choice_mapping_by_index,
-                tool_call_mapping_by_index,
-            )
-
-            # Check guardrails on the last chunk.
-            if (
-                "data: [DONE]" in chunk_text
-                and context.config
-                and context.config.guardrails
-            ):
-                # Block on the guardrails check
-                guardrails_execution_result = await get_guardrails_check_result(
-                    context, merged_response
-                )
-                if guardrails_execution_result.get("errors", []):
-                    error_chunk = json.dumps(
-                        {
-                            "error": {
-                                "message": "[Invariant] The response did not pass the guardrails",
-                                "details": guardrails_execution_result,
-                            }
-                        }
-                    )
-                    # Push annotated trace to the explorer - don't block on its response
-                    if context.dataset_name:
-                        asyncio.create_task(
-                            push_to_explorer(
-                                context,
-                                merged_response,
-                                guardrails_execution_result,
-                            )
-                        )
-                    yield f"data: {error_chunk}\n\n".encode()
-                    return
-
-            # Yield chunk to the client
-            yield chunk
-
-        # Send full merged response to the explorer
-        # Don't block on the response from explorer
-        if context.dataset_name:
-            asyncio.create_task(push_to_explorer(context, merged_response))
-
-    return StreamingResponse(event_generator(), media_type="text/event-stream")
+    return StreamingResponse(
+        response.instrumented_event_generator(), media_type="text/event-stream"
+    )
 
def initialize_merged_response() -> dict[str, Any]:
@@ -329,7 +401,7 @@ def create_metadata(
        {
            key: value
            for key, value in merged_response.items()
            if key in ("usage", "model")
            if key in ("usage", "model") and merged_response.get(key) is not None
        }
    )
    return metadata
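The tightened filter in create_metadata drops keys that are present but still None, which can happen e.g. when a stream ends before a usage chunk arrives. A quick illustration of the difference:

merged_response = {"id": "chatcmpl-1", "model": "gpt-4o", "usage": None}

# old filter: a None usage entry leaks into the metadata
old = {k: v for k, v in merged_response.items() if k in ("usage", "model")}
assert old == {"model": "gpt-4o", "usage": None}

# new filter: only populated keys are kept
new = {
    k: v
    for k, v in merged_response.items()
    if k in ("usage", "model") and merged_response.get(k) is not None
}
assert new == {"model": "gpt-4o"}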
@@ -364,11 +436,13 @@ async def push_to_explorer(


async def get_guardrails_check_result(
    context: RequestContextData, json_response: dict[str, Any]
    context: RequestContextData, json_response: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Get the guardrails check result"""
    messages = list(context.request_json.get("messages", []))
    messages += [choice["message"] for choice in json_response.get("choices", [])]

    if json_response is not None:
        messages += [choice["message"] for choice in json_response.get("choices", [])]

    # Block on the guardrails check
    guardrails_execution_result = await check_guardrails(
@@ -379,49 +453,165 @@ async def get_guardrails_check_result(
    return guardrails_execution_result
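Making json_response optional lets the same helper serve both input-only checks (before the upstream request is made) and full request+response checks. Hypothetical call sites, assuming a populated context:

# input guardrailing: only the request messages are checked
result = await get_guardrails_check_result(context)

# output guardrailing: request messages plus the model's response choices
result = await get_guardrails_check_result(context, json_response)

if result.get("errors", []):
    ...  # refuse the request, or replace the response with an error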
async def handle_non_streaming_response(
    context: RequestContextData, response: httpx.Response
class InstrumentedOpenAIResponse(InstrumentedResponse):
    """
    Does an OpenAI completion request at the core, but also checks guardrails
    before (concurrently with the request) and after it.
    """

    def __init__(
        self,
        context: RequestContextData,
        client: httpx.AsyncClient,
        open_ai_request: httpx.Request,
    ):
        super().__init__()

        # request parameters
        self.context: RequestContextData = context
        self.client: httpx.AsyncClient = client
        self.open_ai_request: httpx.Request = open_ai_request

        # request outputs
        self.response: Optional[httpx.Response] = None
        self.json_response: Optional[dict[str, Any]] = None

        # guardrailing output (if any)
        self.guardrails_execution_result: Optional[dict] = None
    async def on_start(self):
        """Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
        if self.context.config and self.context.config.guardrails:
            # block on the guardrails check
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context
            )
            if self.guardrails_execution_result.get("errors", []):
                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # replace the response with the error message
                return ExtraItem(
                    Response(
                        content=json.dumps(
                            {
                                "error": "[Invariant] The request did not pass the guardrails",
                                "details": self.guardrails_execution_result,
                            }
                        ),
                        status_code=400,
                        media_type="application/json",
                    ),
                    end_of_stream=True,
                )
    async def request(self):
        """Actual OpenAI request."""
        self.response = await self.client.send(self.open_ai_request)

        try:
            self.json_response = self.response.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=self.response.status_code,
                detail="Invalid JSON response received from OpenAI API",
            ) from e
        if self.response.status_code != 200:
            raise HTTPException(
                status_code=self.response.status_code,
                detail=self.json_response.get("error", "Unknown error from OpenAI API"),
            )

        response_string = json.dumps(self.json_response)
        response_code = self.response.status_code

        return Response(
            content=response_string,
            status_code=response_code,
            media_type="application/json",
            headers=dict(self.response.headers),
        )
    async def on_end(self):
        """Postprocesses the OpenAI response and potentially replaces it with a guardrails error."""

        # these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
        # nevertheless, we check for them to avoid any potential issues
        assert (
            self.response is not None
        ), "on_end called before 'self.response' was available"
        assert (
            self.json_response is not None
        ), "on_end called before 'self.json_response' was available"

        # extract original response status code
        response_code = self.response.status_code

        # if we have guardrails, check the response
        if self.context.config and self.context.config.guardrails:
            # run guardrails again, this time on request + response
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.json_response
            )
            if self.guardrails_execution_result.get("errors", []):
                response_string = json.dumps(
                    {
                        "error": "[Invariant] The response did not pass the guardrails",
                        "details": self.guardrails_execution_result,
                    }
                )
                response_code = 400

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.json_response,
                            self.guardrails_execution_result,
                        )
                    )

                # replace the response with the error message
                return ExtraItem(
                    Response(
                        content=response_string,
                        status_code=response_code,
                        media_type="application/json",
                    ),
                )

        # Push annotated trace to the explorer in any case - don't block on its response
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context,
                    self.json_response,
                    # include any guardrailing errors if available
                    self.guardrails_execution_result,
                )
            )
async def handle_non_stream_response(
    context: RequestContextData,
    client: httpx.AsyncClient,
    open_ai_request: httpx.Request,
) -> Response:
    """Handles non-streaming OpenAI responses"""
    try:
        json_response = response.json()
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=response.status_code,
            detail="Invalid JSON response received from OpenAI API",
        ) from e
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=json_response.get("error", "Unknown error from OpenAI API"),
        )

    guardrails_execution_result = {}
    response_string = json.dumps(json_response)
    response_code = response.status_code

    if context.config and context.config.guardrails:
        # Block on the guardrails check
        guardrails_execution_result = await get_guardrails_check_result(
            context, json_response
        )
        if guardrails_execution_result.get("errors", []):
            response_string = json.dumps(
                {
                    "error": "[Invariant] The response did not pass the guardrails",
                    "details": guardrails_execution_result,
                }
            )
            response_code = 400
    if context.dataset_name:
        # Push to Explorer - don't block on its response
        asyncio.create_task(
            push_to_explorer(context, json_response, guardrails_execution_result)
        )

    return Response(
        content=response_string,
        status_code=response_code,
        media_type="application/json",
        headers=dict(response.headers),
    # # execute instrumented request
    # return await instrumentor.execute(send_request())
    response = InstrumentedOpenAIResponse(
        context,
        client,
        open_ai_request,
    )

    return await response.instrumented_request()
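instrumented_request() is the non-streaming counterpart of the stream driver sketched earlier: it runs the on_start input check and the upstream request, and lets on_end replace the response afterwards. A minimal sketch of the pipelining idea, assuming the hook names used above (the real base-class implementation may overlap the two steps differently):

import asyncio


class InstrumentedResponse:
    # on_start, request, and on_end are provided by subclasses such as
    # InstrumentedOpenAIResponse above

    async def instrumented_request(self):
        # fire the upstream request and the input-guardrail check concurrently
        request_task = asyncio.ensure_future(self.request())
        refusal = await self.on_start()

        if refusal is not None:
            # input guardrail failed: drop the upstream result, return the error
            request_task.cancel()
            return refusal.value

        response = await request_task
        # output guardrailing may replace the response with an error
        replacement = await self.on_end()
        return replacement.value if replacement is not None else response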
@@ -15,7 +15,7 @@ UVICORN_PORT=${PORT:-8000}
# using 'exec' below ensures that signals like SIGTERM are passed to the child process
# and not the shell script itself (important when running in a container)
if [ "$DEV_MODE" = "true" ]; then
    exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload
    exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload --reload-dir /srv/resources --reload-dir /srv/gateway
else
    exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT
fi
@@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file(
        == "get_capital is called with Germany as argument"
        and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
    )


@pytest.mark.skipif(
    not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
    explorer_api_url, gateway_url, do_stream, push_to_explorer
):
    """Test input guardrail enforcement with Anthropic."""
    if not os.getenv("INVARIANT_API_KEY"):
        pytest.fail("No INVARIANT_API_KEY set, failing")

    dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"

    client = Anthropic(
        http_client=Client(
            headers={
                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
            },
        ),
        base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
        if push_to_explorer
        else f"{gateway_url}/api/v1/gateway/anthropic",
    )

    request = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
    }

    if not do_stream:
        with pytest.raises(BadRequestError) as exc_info:
            _ = client.messages.create(**request, stream=False)

        assert exc_info.value.status_code == 400
        assert "[Invariant] The request did not pass the guardrails" in str(
            exc_info.value
        )
        assert "Users must not mention the magic phrase 'Fight Club'" in str(
            exc_info.value
        )

    else:
        with pytest.raises(APIStatusError) as exc_info:
            chat_response = client.messages.create(**request, stream=True)
            for _ in chat_response:
                pass

        assert (
            "[Invariant] The request did not pass the guardrails"
            in exc_info.value.message
        )
        assert "Users must not mention the magic phrase 'Fight Club'" in str(
            exc_info.value.body
        )

    if push_to_explorer:
        time.sleep(2)
        traces_response = requests.get(
            f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
            timeout=5,
        )
        traces = traces_response.json()
        assert len(traces) == 1
        trace_id = traces[0]["id"]

        trace_response = requests.get(
            f"{explorer_api_url}/api/v1/trace/{trace_id}",
            timeout=5,
        )
        # in case of input guardrailing, the pushed trace will not contain a response
        trace = trace_response.json()
        assert len(trace["messages"]) == 1, "Only the user message should be present"
        assert trace["messages"][0] == {
            "role": "user",
            "content": "Tell me more about Fight Club.",
        }

        annotations_response = requests.get(
            f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
            timeout=5,
        )
        annotations = annotations_response.json()
        assert len(annotations) == 1
        assert (
            annotations[0]["content"]
            == "Users must not mention the magic phrase 'Fight Club'"
            and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
        )
@@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file(

    else:
        response = client.models.generate_content_stream(**request)
        for chunk in response:
            assert "Dublin" not in str(chunk)
        assert_is_streamed_refusal(
            response,
            [
                "[Invariant] The response did not pass the guardrails",
                "Dublin detected in the response",
            ],
        )

    if push_to_explorer:
        # Wait for the trace to be saved
@@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file(
            **request,
        )

        for chunk in response:
            assert "Madrid" not in str(chunk)
        assert_is_streamed_refusal(
            response,
            [
                "[Invariant] The response did not pass the guardrails",
                "get_capital is called with Germany as argument",
            ],
        )

    if push_to_explorer:
        # Wait for the trace to be saved
@@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file(
        == "get_capital is called with Germany as argument"
        and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
    )
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, push_to_explorer",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_input_from_guardrail_from_file(
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test input guardrail enforcement with Gemini."""
|
||||
if not os.getenv("INVARIANT_API_KEY"):
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"headers": {
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
},
|
||||
)
|
||||
|
||||
request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
"contents": "Tell me more about Fight Club.",
|
||||
"config": {
|
||||
"maxOutputTokens": 200,
|
||||
},
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(genai.errors.ClientError) as exc_info:
|
||||
client.models.generate_content(**request)
|
||||
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
else:
|
||||
response = client.models.generate_content_stream(**request)
|
||||
|
||||
assert_is_streamed_refusal(
|
||||
response,
|
||||
[
|
||||
"[Invariant] The request did not pass the guardrails",
|
||||
"Users must not mention the magic phrase 'Fight Club'",
|
||||
],
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
time.sleep(2)
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 1
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Tell me more about Fight Club."}],
|
||||
}
|
||||
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"]
|
||||
== "Users must not mention the magic phrase 'Fight Club'"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
def is_refusal(chunk):
    return (
        len(chunk.candidates) == 1
        and chunk.candidates[0].content.parts[0].text.startswith("[Invariant]")
        and chunk.prompt_feedback is not None
        and "BlockedReason.SAFETY" in str(chunk.prompt_feedback)
    )


def assert_is_streamed_refusal(response, expected_message_components: list[str]):
    """
    Validates that the streamed response contains a refusal at the end (or as the only message).
    """
    num_chunks = 0
    for c in response:
        num_chunks += 1

    assert num_chunks >= 1, "Expected at least one chunk"
    # last chunk must be a refusal
    assert is_refusal(c)

    for emc in expected_message_components:
        assert (
            emc in c.model_dump_json()
        ), f"Expected message component {emc} not found in refusal message: {c.model_dump_json()}"
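For reference, a refusal chunk emitted by the gateway's Gemini integration carries the error text as the only candidate and marks the prompt as blocked. Schematically (shape inferred from the checks in is_refusal above; field values abridged and illustrative only):

# illustrative refusal chunk, as the checks above would see it
refusal_chunk = {
    "candidates": [
        {
            "content": {
                "parts": [
                    {"text": "[Invariant] The request did not pass the guardrails ..."}
                ]
            }
        }
    ],
    # stringifies to something containing "BlockedReason.SAFETY"
    "prompt_feedback": {"block_reason": "BlockedReason.SAFETY"},
}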
@@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file(
        == "get_capital is called with Germany as argument"
        and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
    )


@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
    explorer_api_url, gateway_url, do_stream, push_to_explorer
):
    """Test input guardrail enforcement with OpenAI."""
    if not os.getenv("INVARIANT_API_KEY"):
        pytest.fail("No INVARIANT_API_KEY set, failing")

    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"

    client = OpenAI(
        http_client=Client(
            headers={
                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
            },
        ),
        base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
        if push_to_explorer
        else f"{gateway_url}/api/v1/gateway/openai",
    )

    request = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
    }

    if not do_stream:
        with pytest.raises(BadRequestError) as exc_info:
            chat_response = client.chat.completions.create(
                **request,
                stream=False,
            )

        assert exc_info.value.status_code == 400
        assert "[Invariant] The request did not pass the guardrails" in str(
            exc_info.value
        )
        assert "Users must not mention the magic phrase 'Fight Club'" in str(
            exc_info.value
        )

    else:
        with pytest.raises(APIError) as exc_info:
            chat_response = client.chat.completions.create(
                **request,
                stream=True,
            )

            for _ in chat_response:
                pass
        assert (
            "[Invariant] The request did not pass the guardrails"
            in exc_info.value.message
        )
        assert "Users must not mention the magic phrase 'Fight Club'" in str(
            exc_info.value.body
        )

    if push_to_explorer:
        # Wait for the trace to be saved
        # This is needed because the trace is saved asynchronously
        time.sleep(2)

        # Fetch the trace ids for the dataset
        traces_response = requests.get(
            f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
            timeout=5,
        )
        traces = traces_response.json()
        assert len(traces) == 1
        trace_id = traces[0]["id"]

        # Fetch the trace
        trace_response = requests.get(
            f"{explorer_api_url}/api/v1/trace/{trace_id}",
            timeout=5,
        )
        trace = trace_response.json()

        # in case of input guardrailing, the pushed trace will not contain a response
        assert len(trace["messages"]) == 1
        assert trace["messages"][0] == {
            "role": "user",
            "content": "Tell me more about Fight Club.",
        }

        # Fetch annotations
        annotations_response = requests.get(
            f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
            timeout=5,
        )
        annotations = annotations_response.json()

        assert len(annotations) == 1
        assert (
            annotations[0]["content"]
            == "Users must not mention the magic phrase 'Fight Club'"
            and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
        )
@@ -13,4 +13,10 @@ raise "Dublin detected in the response" if:
|
||||
raise "get_capital is called with Germany as argument" if:
|
||||
(call: ToolCall)
|
||||
call is tool:get_capital
|
||||
call.function.arguments["country_name"] == "Germany"
|
||||
call.function.arguments["country_name"] == "Germany"
|
||||
|
||||
# For input guardrailing specifically
|
||||
raise "Users must not mention the magic phrase 'Fight Club'" if:
|
||||
(msg: Message)
|
||||
msg.role == "user"
|
||||
"Fight Club" in msg.content
|
||||
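Under this rule, any request whose user message contains the phrase is refused before it ever reaches the upstream model; for example, the payload used in the tests above trips it:

{"model": "gpt-4o", "messages": [{"role": "user", "content": "Tell me more about Fight Club."}]}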