From 7c0bb957fbef82956b5a47d3d207eec0e1f21bb8 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Mon, 31 Mar 2025 14:13:58 +0200 Subject: [PATCH] Pipelined Guardrails (#32) * initial draft: pipelined guardrails * documentation on stream instrumentation * more comments * fix: return earlier * non-streaming case * handle non-streaming case * fix more cases * simplify request instrumentation * improve comments * fix import issues * extend tests for input guardrailing * anthropic integration of pipelined and pre-guardrailing * fix gemini streamed refusal --- .env | 2 +- example_policy.gr | 7 + gateway/common/config_manager.py | 4 +- gateway/integrations/guardails.py | 132 ----- gateway/integrations/guardrails.py | 358 ++++++++++++++ gateway/routes/anthropic.py | 394 +++++++++++---- gateway/routes/gemini.py | 451 ++++++++++++----- gateway/routes/open_ai.py | 460 +++++++++++++----- gateway/run.sh | 2 +- .../guardrails/test_guardrails_anthropic.py | 94 ++++ .../guardrails/test_guardrails_gemini.py | 137 +++++- .../guardrails/test_guardrails_open_ai.py | 105 ++++ .../guardrails/find_capital_guardrails.py | 8 +- 13 files changed, 1659 insertions(+), 495 deletions(-) create mode 100644 example_policy.gr delete mode 100644 gateway/integrations/guardails.py create mode 100644 gateway/integrations/guardrails.py diff --git a/.env b/.env index 410fc9a..af1265e 100644 --- a/.env +++ b/.env @@ -3,4 +3,4 @@ # If you want to push to a local instance of explorer, then specify the app-api docker container name like: # http://:8000 to push to the local explorer instance. INVARIANT_API_URL=https://explorer.invariantlabs.ai -GUADRAILS_API_URL=https://guardrail.invariantnet.com +GUADRAILS_API_URL=https://explorer.invariantlabs.ai diff --git a/example_policy.gr b/example_policy.gr new file mode 100644 index 0000000..cb378ae --- /dev/null +++ b/example_policy.gr @@ -0,0 +1,7 @@ +from invariant.detectors import prompt_injection + +raise "Don't say 'Hello'" if: + (msg: Message) + msg.role == "user" + prompt_injection(msg.content) + # "Hello" in msg.content \ No newline at end of file diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 57a1d89..3e62b55 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -4,8 +4,6 @@ import asyncio import os import threading -from integrations.guardails import _preload - from httpx import HTTPStatusError @@ -20,6 +18,8 @@ class GatewayConfig: Loads the guardrails from the file specified in GUARDRAILS_FILE_PATH. Returns the guardrails file content as a string. """ + from integrations.guardrails import _preload + guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "") if not guardrails_file: print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]") diff --git a/gateway/integrations/guardails.py b/gateway/integrations/guardails.py deleted file mode 100644 index 4c571fc..0000000 --- a/gateway/integrations/guardails.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Utility functions for Guardrails execution.""" - -import asyncio -import os -import time -from typing import Any, Dict, List -from functools import wraps - -import httpx - -DEFAULT_API_URL = "https://guardrail.invariantnet.com" - - -# Timestamps of last API calls per guardrails string -_guardrails_cache = {} -# Locks per guardrails string -_guardrails_locks = {} - - -def rate_limit(expiration_time: int = 3600): - """ - Decorator to limit API calls to once per expiration_time seconds - per unique guardrails string. 
- - Args: - expiration_time (int): Time in seconds to cache the guardrails. - """ - - def decorator(func): - @wraps(func) - async def wrapper(guardrails: str, *args, **kwargs): - now = time.time() - - # Get or create a per-guardrail lock - if guardrails not in _guardrails_locks: - _guardrails_locks[guardrails] = asyncio.Lock() - guardrail_lock = _guardrails_locks[guardrails] - - async with guardrail_lock: - last_called = _guardrails_cache.get(guardrails) - - if last_called and (now - last_called < expiration_time): - # Skipping API call: Guardrails '{guardrails}' already - # preloaded within expiration_time - return - - # Update cache timestamp - _guardrails_cache[guardrails] = now - - try: - await func(guardrails, *args, **kwargs) - finally: - _guardrails_locks.pop(guardrails, None) - - return wrapper - - return decorator - - -@rate_limit(3600) # Don't preload the same guardrails string more than once per hour -async def _preload(guardrails: str, invariant_authorization: str) -> None: - """ - Calls the Guardrails API to preload the provided policy for faster checking later. - - Args: - guardrails (str): The guardrails to preload. - invariant_authorization (str): Value of the - invariant-authorization header. - """ - async with httpx.AsyncClient() as client: - url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") - result = await client.post( - f"{url}/api/v1/policy/load", - json={"policy": guardrails}, - headers={ - "Authorization": invariant_authorization, - "Accept": "application/json", - }, - ) - result.raise_for_status() - - -async def preload_guardrails(context: "RequestContextData") -> None: - """ - Preloads the guardrails for faster checking later. - - Args: - context: RequestContextData object. - """ - if not context.config or not context.config.guardrails: - return - - try: - task = asyncio.create_task( - _preload(context.config.guardrails, context.invariant_authorization) - ) - asyncio.shield(task) - except Exception as e: - print(f"Error scheduling preload_guardrails task: {e}") - - -async def check_guardrails( - messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str -) -> Dict[str, Any]: - """ - Checks guardrails on the list of messages. - - Args: - messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (str): The guardrails to check against. - invariant_authorization (str): Value of the - invariant-authorization header. - - Returns: - Dict: Response containing guardrail check results. 
- """ - async with httpx.AsyncClient() as client: - url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") - try: - result = await client.post( - f"{url}/api/v1/policy/check", - json={"messages": messages, "policy": guardrails}, - headers={ - "Authorization": invariant_authorization, - "Accept": "application/json", - }, - ) - print(f"Guardrail check response: {result.json()}") - return result.json() - except Exception as e: - print(f"Failed to verify guardrails: {e}") - return {"error": str(e)} diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py new file mode 100644 index 0000000..b0f3601 --- /dev/null +++ b/gateway/integrations/guardrails.py @@ -0,0 +1,358 @@ +"""Utility functions for Guardrails execution.""" + +import asyncio +import os +import time +from typing import Any, Dict, List +from functools import wraps + +import httpx +from common.request_context_data import RequestContextData + +DEFAULT_API_URL = "https://explorer.invariantlabs.ai" + + +# Timestamps of last API calls per guardrails string +_guardrails_cache = {} +# Locks per guardrails string +_guardrails_locks = {} + + +def rate_limit(expiration_time: int = 3600): + """ + Decorator to limit API calls to once per expiration_time seconds + per unique guardrails string. + + Args: + expiration_time (int): Time in seconds to cache the guardrails. + """ + + def decorator(func): + @wraps(func) + async def wrapper(guardrails: str, *args, **kwargs): + now = time.time() + + # Get or create a per-guardrail lock + if guardrails not in _guardrails_locks: + _guardrails_locks[guardrails] = asyncio.Lock() + guardrail_lock = _guardrails_locks[guardrails] + + async with guardrail_lock: + last_called = _guardrails_cache.get(guardrails) + + if last_called and (now - last_called < expiration_time): + # Skipping API call: Guardrails '{guardrails}' already + # preloaded within expiration_time + return + + # Update cache timestamp + _guardrails_cache[guardrails] = now + + try: + await func(guardrails, *args, **kwargs) + finally: + _guardrails_locks.pop(guardrails, None) + + return wrapper + + return decorator + + +@rate_limit(3600) # Don't preload the same guardrails string more than once per hour +async def _preload(guardrails: str, invariant_authorization: str) -> None: + """ + Calls the Guardrails API to preload the provided policy for faster checking later. + + Args: + guardrails (str): The guardrails to preload. + invariant_authorization (str): Value of the + invariant-authorization header. + """ + async with httpx.AsyncClient() as client: + url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + result = await client.post( + f"{url}/api/v1/policy/load", + json={"policy": guardrails}, + headers={ + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + ) + result.raise_for_status() + + +async def preload_guardrails(context: "RequestContextData") -> None: + """ + Preloads the guardrails for faster checking later. + + Args: + context: RequestContextData object. + """ + if not context.config or not context.config.guardrails: + return + + try: + task = asyncio.create_task( + _preload(context.config.guardrails, context.invariant_authorization) + ) + asyncio.shield(task) + except Exception as e: + print(f"Error scheduling preload_guardrails task: {e}") + + +class ExtraItem: + """ + Return this class in a instrumented stream callback, to yield an extra item in the resulting stream. 
+ """ + + def __init__(self, value, end_of_stream=False): + self.value = value + self.end_of_stream = end_of_stream + + def __str__(self): + return f"" + + +class Replacement(ExtraItem): + """ + Like ExtraItem, but used to replace the full request result in case of 'InstrumentedResponse'. + """ + + def __init__(self, value): + super().__init__(value, end_of_stream=True) + + def __str__(self): + return f"" + + +class InstrumentedStreamingResponse: + def __init__(self): + # request statistics + self.stat_token_times = [] + self.stat_before_time = None + self.stat_after_time = None + + self.stat_first_item_time = None + + async def on_chunk(self, chunk: Any) -> ExtraItem | None: + """ + This called will be called on every chunk (async). + """ + pass + + async def on_start(self) -> ExtraItem | None: + """ + Decorator to register a listener for start events. + """ + pass + + async def on_end(self) -> ExtraItem | None: + """ + Decorator to register a listener for end events. + """ + pass + + async def event_generator(self): + """ + Streams the async iterable and invokes all instrumented hooks. + + Args: + async_iterable: An async iterable to stream. + + Yields: + The streamed data. + """ + raise NotImplementedError("This method should be implemented in a subclass.") + + async def instrumented_event_generator(self): + """ + Streams the async iterable and invokes all instrumented hooks. + + Args: + async_iterable: An async iterable to stream. + + Yields: + The streamed data. + """ + try: + start = time.time() + + # schedule on_start which can be run concurrently + start_task = asyncio.create_task(self.on_start(), name="instrumentor:start") + + # create async iterator from async_iterable + aiterable = aiter(self.event_generator()) + + # [STAT] capture start time of first item + start_first_item_request = time.time() + + # waits for first item of the iterable + async def wait_for_first_item(): + nonlocal start_first_item_request, aiterable + + r = await aiterable.__anext__() + if self.stat_first_item_time is None: + # [STAT] capture time to first item + self.stat_first_item_time = time.time() - start_first_item_request + return r + + next_item_task = asyncio.create_task( + wait_for_first_item(), name="instrumentor:next:first" + ) + + # check if 'start_task' yields an extra item + if extra_item := await start_task: + # yield extra value before any real items + yield extra_item.value + # stop the stream if end_of_stream is True + if extra_item.end_of_stream: + # if first item is already available + if not next_item_task.done(): + # cancel the task + next_item_task.cancel() + # [STAT] capture time to first item to be now +0.01 + if self.stat_first_item_time is None: + self.stat_first_item_time = ( + time.time() - start_first_item_request + ) + 0.01 + # don't wait for the first item if end_of stream is True + return + + # [STAT] capture before time stamp + self.stat_before_time = time.time() - start + + while True: + # wait for first item + try: + item = await next_item_task + except StopAsyncIteration: + break + + # schedule next item + next_item_task = asyncio.create_task( + aiterable.__anext__(), name="instrumentor:next" + ) + + # [STAT] capture token time stamp + if len(self.stat_token_times) == 0: + self.stat_token_times.append(time.time() - start) + else: + self.stat_token_times.append( + time.time() - start - sum(self.stat_token_times) + ) + + if extra_item := await self.on_chunk(item): + yield extra_item.value + # if end_of_stream is True, stop the stream + if extra_item.end_of_stream: + # 
cancel next task + next_item_task.cancel() + return + + # yield item + yield item + + # run on_end, before closing the stream (may yield an extra value) + if extra_item := await self.on_end(): + # yield extra value before any real items + yield extra_item.value + # we ignore end_of_stream here, because we are already at the end + + # [STAT] capture after time stamp + self.stat_after_time = time.time() - start + finally: + # [STAT] end all open intervals if not already closed + if self.stat_after_time is None: + self.stat_before_time = time.time() - start + if self.stat_after_time is None: + self.stat_after_time = 0 + if self.stat_first_item_time is None: + self.stat_first_item_time = 0 + + # print statistics + token_times_5_decimale = str([f"{x:.5f}" for x in self.stat_token_times]) + print( + f"[STATS]\n [token times: {token_times_5_decimale} ({len(self.stat_token_times)})]" + ) + print(f" [before: {self.stat_before_time:.2f}s] ") + print(f" [time-to-first-item: {self.stat_first_item_time:.2f}s]") + print( + f" [zero-latency: {' TRUE' if self.stat_before_time < self.stat_first_item_time else 'FALSE'}]" + ) + print( + f" [extra-latency: {self.stat_before_time - self.stat_first_item_time:.2f}s]" + ) + print(f" [after: {self.stat_after_time:.2f}s]") + if len(self.stat_token_times) > 0: + print( + f" [average token time: {sum(self.stat_token_times) / len(self.stat_token_times):.2f}s]" + ) + print(f" [total: {time.time() - start:.2f}s]") + + +class InstrumentedResponse(InstrumentedStreamingResponse): + """ + A class to instrument an async request with hooks for concurrent + pre-processing and post-processing (input and output guardrailing). + """ + + async def event_generator(self): + """ + We implement the 'event_generator' as a single item stream, + where the item is the full result of the request. + """ + yield await self.request() + + async def request(self): + """ + This method should be implemented in a subclass to perform the actual request. + """ + raise NotImplementedError("This method should be implemented in a subclass.") + + async def instrumented_request(self): + """ + Returns the 'Response' object of the request, after applying all instrumented hooks. + """ + results = [r async for r in self.instrumented_event_generator()] + assert len(results) >= 1, "InstrumentedResponse must yield at least one item" + + # we return the last item, in case the end callback yields an extra item. Then, + # don't return the actual result but the 'end' result, e.g. for output guardrailing. + return results[-1] + + +async def check_guardrails( + messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str +) -> Dict[str, Any]: + """ + Checks guardrails on the list of messages. + + Args: + messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. + guardrails (str): The guardrails to check against. + invariant_authorization (str): Value of the + invariant-authorization header. + + Returns: + Dict: Response containing guardrail check results. 
+ """ + async with httpx.AsyncClient() as client: + url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + try: + result = await client.post( + f"{url}/api/v1/policy/check", + json={"messages": messages, "policy": guardrails}, + headers={ + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + ) + if not result.is_success: + raise Exception( + f"Guardrails check failed: {result.status_code} - {result.text}" + ) + print(f"Guardrail check response: {result.json()}") + return result.json() + except Exception as e: + print(f"Failed to verify guardrails: {e}") + return {"error": str(e)} diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 01004c6..2f3e243 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -5,6 +5,7 @@ import json from typing import Any, Optional import httpx +from regex import R from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse @@ -18,7 +19,14 @@ from converters.anthropic_to_invariant import ( ) from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData -from integrations.guardails import check_guardrails, preload_guardrails +from integrations.guardrails import ( + ExtraItem, + InstrumentedResponse, + InstrumentedStreamingResponse, + Replacement, + check_guardrails, + preload_guardrails, +) gateway = APIRouter() @@ -85,8 +93,7 @@ async def anthropic_v1_messages_gateway( if request_json.get("stream"): return await handle_streaming_response(context, client, anthropic_request) - response = await client.send(anthropic_request) - return await handle_non_streaming_response(context, response) + return await handle_non_streaming_response(context, client, anthropic_request) def create_metadata( @@ -110,7 +117,8 @@ def combine_request_and_response_messages( {"role": "system", "content": context.request_json.get("system")} ) messages.extend(context.request_json.get("messages", [])) - messages.append(json_response) + if len(json_response) > 0: + messages.append(json_response) return messages @@ -154,56 +162,282 @@ async def push_to_explorer( ) +class InstrumentedAnthropicResponse(InstrumentedResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + anthropic_request: httpx.Request, + ): + super().__init__() + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.anthropic_request: httpx.Request = anthropic_request + + # response data + self.response: Optional[httpx.Response] = None + self.response_string: Optional[str] = None + self.json_response: Optional[dict[str, Any]] = None + + # guardrailing response (if any) + self.guardrails_execution_result = {} + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + 
push_to_explorer(
+                            self.context,
+                            {},
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we prevent the request from going through
+                # and return an error instead
+                return Replacement(
+                    Response(
+                        content=error_chunk,
+                        status_code=400,
+                        media_type="application/json",
+                        headers={"content-type": "application/json"},
+                    )
+                )
+
+    async def request(self):
+        self.response = await self.client.send(self.anthropic_request)
+
+        try:
+            json_response = self.response.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error {e}",
+            ) from e
+        if self.response.status_code != 200:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=json_response.get("error", "Unknown error from Anthropic"),
+            )
+
+        self.json_response = json_response
+        self.response_string = json.dumps(json_response)
+
+        return self._make_response(
+            content=self.response_string,
+            status_code=self.response.status_code,
+        )
+
+    def _make_response(self, content: str, status_code: int):
+        """Creates a new Response object with the correct headers and content"""
+        assert self.response is not None, "response is None"
+
+        updated_headers = self.response.headers.copy()
+        updated_headers.pop("Content-Length", None)
+
+        return Response(
+            content=content,
+            status_code=status_code,
+            media_type="application/json",
+            headers=dict(updated_headers),
+        )
+
+    async def on_end(self):
+        """Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
+        # ensure the response data is available
+        assert self.response is not None, "response is None"
+        assert self.json_response is not None, "json_response is None"
+        assert self.response_string is not None, "response_string is None"
+
+        if self.context.config and self.context.config.guardrails:
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.json_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                guardrail_response_string = json.dumps(
+                    {
+                        "error": "[Invariant] The response did not pass the guardrails",
+                        "details": self.guardrails_execution_result,
+                    }
+                )
+
+                # push to explorer (if configured)
+                if self.context.dataset_name:
+                    # Push to Explorer - don't block on its response
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.json_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                return Replacement(
+                    self._make_response(
+                        content=guardrail_response_string,
+                        status_code=400,
+                    )
+                )
+
+        # push to explorer (if configured)
+        if self.context.dataset_name:
+            # Push to Explorer - don't block on its response
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.json_response, self.guardrails_execution_result
+                )
+            )
+
+
 async def handle_non_streaming_response(
     context: RequestContextData,
-    response: httpx.Response,
+    client: httpx.AsyncClient,
+    anthropic_request: httpx.Request,
 ) -> Response:
     """Handles non-streaming Anthropic responses"""
-    try:
-        json_response = response.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
-        ) from e
-    if response.status_code != 200:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=json_response.get("error", "Unknown error from Anthropic"),
-        )
-
-    guardrails_execution_result = {}
-    response_string = json.dumps(json_response)
-    response_code = response.status_code
-
-    if context.config and context.config.guardrails:
-        # Block on the guardrails check
-        guardrails_execution_result = await get_guardrails_check_result(
-            context, json_response
-        )
-        if guardrails_execution_result.get("errors", []):
-            response_string = json.dumps(
-                {
-                    "error": "[Invariant] The response did not pass the guardrails",
-                    "details": guardrails_execution_result,
-                }
-            )
-            response_code = 400
-    if context.dataset_name:
-        # Push to Explorer - don't block on its response
-        asyncio.create_task(
-            push_to_explorer(context, json_response, guardrails_execution_result)
-        )
-
-    updated_headers = response.headers.copy()
-    updated_headers.pop("Content-Length", None)
-    return Response(
-        content=response_string,
-        status_code=response_code,
-        media_type="application/json",
-        headers=dict(updated_headers),
+    response = InstrumentedAnthropicResponse(
+        context=context,
+        client=client,
+        anthropic_request=anthropic_request,
     )
+    return await response.instrumented_request()
+
+
+class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        anthropic_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request parameters
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.anthropic_request: httpx.Request = anthropic_request
+
+        # response data
+        self.merged_response = {}
+
+        # guardrailing response (if any)
+        self.guardrails_execution_result = {}
+
+    async def on_start(self):
+        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The request did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.merged_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # if we find something, we end the stream prematurely (end_of_stream=True)
+                # and yield an error chunk instead of actually beginning the stream
+                return ExtraItem(
+                    f"event: error\ndata: {error_chunk}\n\n".encode(),
+                    end_of_stream=True,
+                )
+
+    async def event_generator(self):
+        """Actual streaming response generator"""
+        response = await self.client.send(self.anthropic_request, stream=True)
+        if response.status_code != 200:
+            error_content = await response.aread()
+            try:
+                error_json = json.loads(error_content)
+                error_detail = error_json.get("error", "Unknown error from Anthropic")
+            except json.JSONDecodeError:
+                error_detail = {
+                    "error": "Failed to decode error response from Anthropic"
+                }
+            raise HTTPException(status_code=response.status_code, detail=error_detail)
+
+        # iterate over the response stream
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+    async def on_chunk(self, chunk):
+        decoded_chunk = chunk.decode().strip()
+        if not decoded_chunk:
+            return
+
+        # process chunk and extend the merged_response
+        process_chunk(decoded_chunk, self.merged_response)
+
+        # on last stream chunk, run output guardrails
+        if (
+            "event: message_stop" in decoded_chunk
+            and self.context.config
+            and self.context.config.guardrails
+        ):
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "type": "error",
+                        "error": {
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        },
+                    }
+                )
+
+                # yield an extra error chunk (without preventing the original chunk from
+                # going through afterwards, so the client still gets the proper message_stop event)
+                return ExtraItem(
+                    value=f"event: error\ndata: {error_chunk}\n\n".encode()
+                )
+
+    async def on_end(self):
+        """on_end: send full merged response to the explorer (if configured)"""
+        # don't block on the response from explorer (.create_task)
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context,
+                    self.merged_response,
+                    self.guardrails_execution_result,
+                )
+            )
+
 
 async def handle_streaming_response(
     context: RequestContextData,
@@ -211,63 +445,15 @@ async def handle_streaming_response(
     anthropic_request: httpx.Request,
 ) -> StreamingResponse:
     """Handles streaming Anthropic responses"""
-    merged_response = {}
+    response = InstrumentedAnthropicStreamingResponse(
+        context=context,
+        client=client,
+        anthropic_request=anthropic_request,
+    )
 
-    response = await client.send(anthropic_request, stream=True)
-    if response.status_code != 200:
-        error_content = await response.aread()
-        try:
-            error_json = json.loads(error_content)
-            error_detail = error_json.get("error", "Unknown error from Anthropic")
-        except json.JSONDecodeError:
-            error_detail = {"error": "Failed to decode error response from Anthropic"}
-        raise HTTPException(status_code=response.status_code, detail=error_detail)
-
-    async def event_generator() -> Any:
-        async for chunk in response.aiter_bytes():
-            decoded_chunk = chunk.decode().strip()
-            if not decoded_chunk:
-                continue
-            process_chunk(decoded_chunk, merged_response)
-            if (
-                "event: message_stop" in decoded_chunk
-                and context.config
-                and context.config.guardrails
-            ):
-                # Block on the guardrails check
-                guardrails_execution_result = await get_guardrails_check_result(
-                    context, merged_response
-                )
-                if guardrails_execution_result.get("errors", []):
-                    error_chunk = json.dumps(
-                        {
-                            "type": "error",
-                            "error": {
-                                "message": "[Invariant] The response did not pass the guardrails",
-                                "details": guardrails_execution_result,
-                            },
-                        }
-                    )
-                    # Push annotated trace to the explorer - don't block on its response
-                    if context.dataset_name:
-                        asyncio.create_task(
-                            push_to_explorer(
-                                context,
-                                merged_response,
-                                guardrails_execution_result,
-                            )
-                        )
-                    yield f"event: error\ndata: {error_chunk}\n\n".encode()
-                    return
-            yield chunk
-
-        if context.dataset_name:
-            # Push to Explorer - don't block on the response
-            asyncio.create_task(push_to_explorer(context, merged_response))
-
-    generator = event_generator()
-
-    return StreamingResponse(generator, media_type="text/event-stream")
+    return StreamingResponse(
+        response.instrumented_event_generator(), media_type="text/event-stream"
+    )
 
 
 def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py
index 4042b4e..59d0874 100644
--- a/gateway/routes/gemini.py
+++ b/gateway/routes/gemini.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import json
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import httpx
 from 
common.config_manager import GatewayConfig, GatewayConfigManager @@ -15,8 +15,16 @@ from common.constants import ( from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData from converters.gemini_to_invariant import convert_request, convert_response +from integrations.guardrails import ( + ExtraItem, + InstrumentedResponse, + InstrumentedStreamingResponse, + Replacement, + preload_guardrails, + check_guardrails, +) from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from integrations.guardails import check_guardrails, preload_guardrails +from integrations.guardrails import check_guardrails, preload_guardrails gateway = APIRouter() @@ -82,13 +90,169 @@ async def gemini_generate_content_gateway( client, gemini_request, ) - response = await client.send(gemini_request) return await handle_non_streaming_response( context, - response, + client, + gemini_request, ) +class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + gemini_request: httpx.Request, + ): + super().__init__() + + # request data + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.gemini_request: httpx.Request = gemini_request + + # Store the progressively merged response + self.merged_response = { + "candidates": [{"content": {"parts": []}, "finishReason": None}] + } + + # guardrailing execution result (if any) + self.guardrails_execution_result: Optional[dict[str, Any]] = None + + def make_refusal( + self, + location: Literal["request", "response"], + guardrails_execution_result: dict[str, Any], + ) -> dict: + return { + "candidates": [ + { + "content": { + "parts": [ + { + "text": f"[Invariant] The {location} did not pass the guardrails", + } + ], + } + } + ], + "error": { + "code": 400, + "message": f"[Invariant] The {location} did not pass the guardrails", + "details": guardrails_execution_result, + "status": "INVARIANT_GUARDRAILS_VIOLATION", + }, + "promptFeedback": { + "blockReason": "SAFETY", + "block_reason_message": f"[Invariant] The {location} did not pass the guardrails: " + + json.dumps(guardrails_execution_result), + "safetyRatings": [ + { + "category": "HARM_CATEGORY_UNSPECIFIED", + "probability": "HIGH", + "blocked": True, + } + ], + }, + } + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + self.make_refusal("request", self.guardrails_execution_result) + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + {}, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return ExtraItem( + f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True + ) + + async def event_generator(self): + response = await self.client.send(self.gemini_request, stream=True) + + if response.status_code != 200: + error_content = await response.aread() + try: + error_json = json.loads(error_content.decode("utf-8")) + 
error_detail = error_json.get("error", "Unknown error from Gemini API") + except json.JSONDecodeError: + error_detail = {"error": "Failed to parse Gemini error response"} + raise HTTPException(status_code=response.status_code, detail=error_detail) + + async for chunk in response.aiter_bytes(): + yield chunk + + async def on_chunk(self, chunk): + chunk_text = chunk.decode().strip() + if not chunk_text: + return + + # Parse and update merged_response incrementally + process_chunk_text(self.merged_response, chunk_text) + + # runs on the last stream item + if ( + self.merged_response.get("candidates", []) + and self.merged_response.get("candidates")[0].get("finishReason", "") + and self.context.config + and self.context.config.guardrails + ): + # Block on the guardrails check + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + self.make_refusal("response", self.guardrails_execution_result) + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + return ExtraItem( + value=f"data: {error_chunk}\r\n\r\n".encode(), + # for Gemini we have to end the stream prematurely, as the client SDK + # will not stop streaming when it encounters an error + end_of_stream=True, + ) + + async def on_end(self): + """Runs when the stream ends.""" + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + async def stream_response( context: RequestContextData, client: httpx.AsyncClient, @@ -96,76 +260,21 @@ async def stream_response( ) -> Response: """Handles streaming the Gemini response to the client""" - response = await client.send(gemini_request, stream=True) - if response.status_code != 200: - error_content = await response.aread() - try: - error_json = json.loads(error_content.decode("utf-8")) - error_detail = error_json.get("error", "Unknown error from Gemini API") - except json.JSONDecodeError: - error_detail = {"error": "Failed to parse Gemini error response"} - raise HTTPException(status_code=response.status_code, detail=error_detail) + response = InstrumentedStreamingGeminiResponse( + context=context, + client=client, + gemini_request=gemini_request, + ) - async def event_generator() -> Any: - # Store the progressively merged response - merged_response = { - "candidates": [{"content": {"parts": []}, "finishReason": None}] - } - - async for chunk in response.aiter_bytes(): - chunk_text = chunk.decode().strip() - if not chunk_text: - continue - - # Parse and update merged_response incrementally - process_chunk_text(merged_response, chunk_text) - - if ( - merged_response.get("candidates", []) - and merged_response.get("candidates")[0].get("finishReason", "") - and context.config - and context.config.guardrails - ): - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "error": { - "code": 400, - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - "status": "INVARIANT_GUARDRAILS_VIOLATION", - 
}, - } - ) - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - yield f"data: {error_chunk}\n\n".encode() - return - - # Yield chunk immediately to the client + async def event_generator(): + async for chunk in response.instrumented_event_generator(): yield chunk + print("chunk", chunk) - if context.dataset_name: - # Push to Explorer - don't block on the response - asyncio.create_task( - push_to_explorer( - context, - merged_response, - ) - ) - - return StreamingResponse(event_generator(), media_type="text/event-stream") + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) def process_chunk_text( @@ -281,53 +390,165 @@ async def push_to_explorer( ) +class InstrumentedGeminiResponse(InstrumentedResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + gemini_request: httpx.Request, + ): + super().__init__() + + # request data + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.gemini_request: httpx.Request = gemini_request + + # response data + self.response: Optional[httpx.Response] = None + self.response_json: Optional[dict[str, Any]] = None + + # guardrails execution result (if any) + self.guardrails_execution_result: Optional[dict[str, Any]] = None + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "code": 400, + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + "status": "INVARIANT_GUARDRAILS_VIOLATION", + }, + "prompt_feedback": { + "blockReason": "SAFETY", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_UNSPECIFIED", + "probability": 0.0, + "blocked": True, + } + ], + }, + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + {}, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return Replacement( + Response( + content=error_chunk, + status_code=400, + media_type="application/json", + headers={ + "Content-Type": "application/json", + }, + ) + ) + + async def request(self): + self.response = await self.client.send(self.gemini_request) + + response_string = self.response.text + response_code = self.response.status_code + + try: + self.response_json = self.response.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=self.response.status_code, + detail="Invalid JSON response received from Gemini API", + ) from e + if self.response.status_code != 200: + raise HTTPException( + status_code=self.response.status_code, + detail=self.response_json.get("error", "Unknown error from Gemini API"), + ) + + return Response( + content=response_string, + status_code=response_code, + media_type="application/json", + headers=dict(self.response.headers), + ) + + async def on_end(self): + response_string = json.dumps(self.response_json) + 
response_code = self.response.status_code
+
+        if self.context.config and self.context.config.guardrails:
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.response_json
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                response_string = json.dumps(
+                    {
+                        "error": {
+                            "code": 400,
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
+                        },
+                    }
+                )
+                response_code = 400
+
+                if self.context.dataset_name:
+                    # Push to Explorer - don't block on its response
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.response_json,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                return Replacement(
+                    Response(
+                        content=response_string,
+                        status_code=response_code,
+                        media_type="application/json",
+                        headers=dict(self.response.headers),
+                    )
+                )
+
+        # Otherwise, also push to Explorer - don't block on its response
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.response_json, self.guardrails_execution_result
+                )
+            )
+
+
 async def handle_non_streaming_response(
     context: RequestContextData,
-    response: httpx.Response,
+    client: httpx.AsyncClient,
+    gemini_request: httpx.Request,
 ) -> Response:
     """Handles non-streaming Gemini responses"""
-    try:
-        response_json = response.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail="Invalid JSON response received from Gemini API",
-        ) from e
-    if response.status_code != 200:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=response_json.get("error", "Unknown error from Gemini API"),
-        )
-    guardrails_execution_result = {}
-    response_string = json.dumps(response_json)
-    response_code = response.status_code
-    if context.config and context.config.guardrails:
-        # Block on the guardrails check
-        guardrails_execution_result = await get_guardrails_check_result(
-            context, response_json
-        )
-        if guardrails_execution_result.get("errors", []):
-            response_string = json.dumps(
-                {
-                    "error": {
-                        "code": 400,
-                        "message": "[Invariant] The response did not pass the guardrails",
-                        "details": guardrails_execution_result,
-                        "status": "INVARIANT_GUARDRAILS_VIOLATION",
-                    },
-                }
-            )
-            response_code = 400
-    if context.dataset_name:
-        # Push to Explorer - don't block on its response
-        asyncio.create_task(
-            push_to_explorer(context, response_json, guardrails_execution_result)
-        )
-
-    return Response(
-        content=response_string,
-        status_code=response_code,
-        media_type="application/json",
-        headers=dict(response.headers),
+    response = InstrumentedGeminiResponse(
+        context=context,
+        client=client,
+        gemini_request=gemini_request,
     )
+
+    return await response.instrumented_request()
diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py
index 897dbb8..6ef3808 100644
--- a/gateway/routes/open_ai.py
+++ b/gateway/routes/open_ai.py
@@ -13,7 +13,13 @@ from common.constants import (
     IGNORED_HEADERS,
 )
 from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
-from integrations.guardails import check_guardrails, preload_guardrails
+from integrations.guardrails import (
+    ExtraItem,
+    InstrumentedResponse,
+    InstrumentedStreamingResponse,
+    check_guardrails,
+    preload_guardrails,
+)
 from common.authorization import extract_authorization_from_headers
 from common.request_context_data import RequestContextData
 
@@ -74,19 +80,159 @@ async def 
openai_chat_completions_gateway( asyncio.create_task(preload_guardrails(context)) if request_json.get("stream", False): - return await stream_response( + return await handle_stream_response( context, client, open_ai_request, ) - response = await client.send(open_ai_request) - return await handle_non_streaming_response( - context, - response, - ) + + return await handle_non_stream_response(context, client, open_ai_request) -async def stream_response( +class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): + """ + Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + """ + + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + open_ai_request: httpx.Request, + ): + super().__init__() + + # request parameters + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.open_ai_request: httpx.Request = open_ai_request + + # guardrailing output (if any) + self.guardrails_execution_result: Optional[dict] = None + + # merged_response will be updated with the data from the chunks in the stream + # At the end of the stream, this will be sent to the explorer + self.merged_response = { + "id": None, + "object": "chat.completion", + "created": None, + "model": None, + "choices": [], + "usage": None, + } + + # Each chunk in the stream contains a list called "choices" each entry in the list + # has an index. + # A choice has a field called "delta" which may contain a list called "tool_calls". + # Maps the choice index in the stream to the index in the merged_response["choices"] list + self.choice_mapping_by_index = {} + # Combines the choice index and tool call index to uniquely identify a tool call + self.tool_call_mapping_by_index = {} + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return ExtraItem( + f"data: {error_chunk}\n\n".encode(), + end_of_stream=True, + ) + + async def on_chunk(self, chunk): + # process and check each chunk + chunk_text = chunk.decode().strip() + if not chunk_text: + return + + # Process the chunk + # This will update merged_response with the data from the chunk + process_chunk_text( + chunk_text, + self.merged_response, + self.choice_mapping_by_index, + self.tool_call_mapping_by_index, + ) + + # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) 

+        if (
+            "data: [DONE]" in chunk_text
+            and self.context.config
+            and self.context.config.guardrails
+        ):
+            # Block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.merged_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                error_chunk = json.dumps(
+                    {
+                        "error": {
+                            "message": "[Invariant] The response did not pass the guardrails",
+                            "details": self.guardrails_execution_result,
+                        }
+                    }
+                )
+
+                # yield an extra error chunk (without preventing the original chunk from going through afterwards)
+                return ExtraItem(f"data: {error_chunk}\n\n".encode())
+
+        # push will happen in on_end
+
+    async def on_end(self):
+        """Sends full merged response to the explorer."""
+        # don't block on the response from explorer (.create_task)
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context, self.merged_response, self.guardrails_execution_result
+                )
+            )
+
+    async def event_generator(self):
+        """
+        Actual OpenAI stream response.
+        """
+
+        response = await self.client.send(self.open_ai_request, stream=True)
+        if response.status_code != 200:
+            error_content = await response.aread()
+            try:
+                error_json = json.loads(error_content.decode("utf-8"))
+                error_detail = error_json.get("error", "Unknown error from OpenAI API")
+            except json.JSONDecodeError:
+                error_detail = {"error": "Failed to parse OpenAI error response"}
+            raise HTTPException(status_code=response.status_code, detail=error_detail)
+
+        # stream out chunks
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+
+async def handle_stream_response(
     context: RequestContextData,
     client: httpx.AsyncClient,
     open_ai_request: httpx.Request,
@@ -98,89 +244,15 @@ async def handle_stream_response(
     It is sent to the Invariant Explorer at the end of the stream
     """
-    response = await client.send(open_ai_request, stream=True)
-    if response.status_code != 200:
-        error_content = await response.aread()
-        try:
-            error_json = json.loads(error_content.decode("utf-8"))
-            error_detail = error_json.get("error", "Unknown error from OpenAI API")
-        except json.JSONDecodeError:
-            error_detail = {"error": "Failed to parse OpenAI error response"}
-        raise HTTPException(status_code=response.status_code, detail=error_detail)
+    response = InstrumentedOpenAIStreamResponse(
+        context,
+        client,
+        open_ai_request,
+    )
 
-    async def event_generator() -> Any:
-        # merged_response will be updated with the data from the chunks in the stream
-        # At the end of the stream, this will be sent to the explorer
-        merged_response = {
-            "id": None,
-            "object": "chat.completion",
-            "created": None,
-            "model": None,
-            "choices": [],
-            "usage": None,
-        }
-        # Each chunk in the stream contains a list called "choices" each entry in the list
-        # has an index.
-        # A choice has a field called "delta" which may contain a list called "tool_calls".
-        # Maps the choice index in the stream to the index in the merged_response["choices"] list
-        choice_mapping_by_index = {}
-        # Combines the choice index and tool call index to uniquely identify a tool call
-        tool_call_mapping_by_index = {}
-
-        async for chunk in response.aiter_bytes():
-            chunk_text = chunk.decode().strip()
-            if not chunk_text:
-                continue
-
-            # Process the chunk
-            # This will update merged_response with the data from the chunk
-            process_chunk_text(
-                chunk_text,
-                merged_response,
-                choice_mapping_by_index,
-                tool_call_mapping_by_index,
-            )
-
-            # Check guardrails on the last chunk.
- if ( - "data: [DONE]" in chunk_text - and context.config - and context.config.guardrails - ): - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "error": { - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - } - } - ) - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - yield f"data: {error_chunk}\n\n".encode() - return - - # Yield chunk to the client - yield chunk - - # Send full merged response to the explorer - # Don't block on the response from explorer - if context.dataset_name: - asyncio.create_task(push_to_explorer(context, merged_response)) - - return StreamingResponse(event_generator(), media_type="text/event-stream") + return StreamingResponse( + response.instrumented_event_generator(), media_type="text/event-stream" + ) def initialize_merged_response() -> dict[str, Any]: @@ -329,7 +401,7 @@ def create_metadata( { key: value for key, value in merged_response.items() - if key in ("usage", "model") + if key in ("usage", "model") and merged_response.get(key) is not None } ) return metadata @@ -364,11 +436,13 @@ async def push_to_explorer( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContextData, json_response: dict[str, Any] | None = None ) -> dict[str, Any]: """Get the guardrails check result""" messages = list(context.request_json.get("messages", [])) - messages += [choice["message"] for choice in json_response.get("choices", [])] + + if json_response is not None: + messages += [choice["message"] for choice in json_response.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( @@ -379,49 +453,165 @@ async def get_guardrails_check_result( return guardrails_execution_result -async def handle_non_streaming_response( - context: RequestContextData, response: httpx.Response +class InstrumentedOpenAIResponse(InstrumentedResponse): + """ + Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. 
+    """
+
+    def __init__(
+        self,
+        context: RequestContextData,
+        client: httpx.AsyncClient,
+        open_ai_request: httpx.Request,
+    ):
+        super().__init__()
+
+        # request parameters
+        self.context: RequestContextData = context
+        self.client: httpx.AsyncClient = client
+        self.open_ai_request: httpx.Request = open_ai_request
+
+        # request outputs
+        self.response: Optional[httpx.Response] = None
+        self.json_response: Optional[dict[str, Any]] = None
+
+        # guardrailing output (if any)
+        self.guardrails_execution_result: Optional[dict] = None
+
+    async def on_start(self):
+        """Checks guardrails in a pipelined fashion before the first chunk is processed (input guardrailing)."""
+        if self.context.config and self.context.config.guardrails:
+            # block on the guardrails check
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            {},
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # replace the response with the error message
+                return ExtraItem(
+                    Response(
+                        content=json.dumps(
+                            {
+                                "error": "[Invariant] The request did not pass the guardrails",
+                                "details": self.guardrails_execution_result,
+                            }
+                        ),
+                        status_code=400,
+                        media_type="application/json",
+                    ),
+                    end_of_stream=True,
+                )
+
+    async def request(self):
+        """Performs the actual OpenAI request."""
+        self.response = await self.client.send(self.open_ai_request)
+
+        try:
+            self.json_response = self.response.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail="Invalid JSON response received from OpenAI API",
+            ) from e
+        if self.response.status_code != 200:
+            raise HTTPException(
+                status_code=self.response.status_code,
+                detail=self.json_response.get("error", "Unknown error from OpenAI API"),
+            )
+
+        response_string = json.dumps(self.json_response)
+        response_code = self.response.status_code
+
+        return Response(
+            content=response_string,
+            status_code=response_code,
+            media_type="application/json",
+            headers=dict(self.response.headers),
+        )
+
+    async def on_end(self):
+        """Postprocesses the OpenAI response and potentially replaces it with a guardrails error."""
+
+        # both request outputs are guaranteed to be available at this point
+        # (self.request() has already executed); we still assert on them to
+        # fail loudly if the lifecycle is ever violated
+        assert (
+            self.response is not None
+        ), "on_end called before 'self.response' was available"
+        assert (
+            self.json_response is not None
+        ), "on_end called before 'self.json_response' was available"
+
+        # extract original response status code
+        response_code = self.response.status_code
+
+        # if we have guardrails, check the response
+        if self.context.config and self.context.config.guardrails:
+            # run guardrails again, this time on request + response
+            self.guardrails_execution_result = await get_guardrails_check_result(
+                self.context, self.json_response
+            )
+            if self.guardrails_execution_result.get("errors", []):
+                response_string = json.dumps(
+                    {
+                        "error": "[Invariant] The response did not pass the guardrails",
+                        "details": self.guardrails_execution_result,
+                    }
+                )
+                response_code = 400
+
+                # Push annotated trace to the explorer - don't block on its response
+                if self.context.dataset_name:
+                    asyncio.create_task(
+                        push_to_explorer(
+                            self.context,
+                            self.json_response,
+                            self.guardrails_execution_result,
+                        )
+                    )
+
+                # replace the response with the error message
+                return ExtraItem(
+                    Response(
+                        content=response_string,
+                        status_code=response_code,
+                        media_type="application/json",
+                    ),
+                )
+
+        # Push annotated trace to the explorer in any case - don't block on its response
+        if self.context.dataset_name:
+            asyncio.create_task(
+                push_to_explorer(
+                    self.context,
+                    self.json_response,
+                    # include any guardrailing errors if available
+                    self.guardrails_execution_result,
+                )
+            )
+
+
+async def handle_non_stream_response(
+    context: RequestContextData,
+    client: httpx.AsyncClient,
+    open_ai_request: httpx.Request,
 ) -> Response:
     """Handles non-streaming OpenAI responses"""
-    try:
-        json_response = response.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail="Invalid JSON response received from OpenAI API",
-        ) from e
-    if response.status_code != 200:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=json_response.get("error", "Unknown error from OpenAI API"),
-        )
-    guardrails_execution_result = {}
-    response_string = json.dumps(json_response)
-    response_code = response.status_code
-
-    if context.config and context.config.guardrails:
-        # Block on the guardrails check
-        guardrails_execution_result = await get_guardrails_check_result(
-            context, json_response
-        )
-        if guardrails_execution_result.get("errors", []):
-            response_string = json.dumps(
-                {
-                    "error": "[Invariant] The response did not pass the guardrails",
-                    "details": guardrails_execution_result,
-                }
-            )
-            response_code = 400
-    if context.dataset_name:
-        # Push to Explorer - don't block on its response
-        asyncio.create_task(
-            push_to_explorer(context, json_response, guardrails_execution_result)
-        )
-
-    return Response(
-        content=response_string,
-        status_code=response_code,
-        media_type="application/json",
-        headers=dict(response.headers),
+    # execute the request with pipelined guardrail instrumentation
+    response = InstrumentedOpenAIResponse(
+        context,
+        client,
+        open_ai_request,
     )
+
+    return await response.instrumented_request()
diff --git a/gateway/run.sh b/gateway/run.sh
index ebc4b1b..4cee98c 100755
--- a/gateway/run.sh
+++ b/gateway/run.sh
@@ -15,7 +15,7 @@ UVICORN_PORT=${PORT:-8000}
 # using 'exec' belows ensures that signals like SIGTERM are passed to the child process
 # and not the shell script itself (important when running in a container)
 if [ "$DEV_MODE" = "true" ]; then
-    exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload
+    exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT --reload --reload-dir /srv/resources --reload-dir /srv/gateway
 else
     exec uvicorn serve:app --host 0.0.0.0 --port $UVICORN_PORT
 fi
\ No newline at end of file
diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py
index 45ba1fd..173e5ab 100644
--- a/tests/integration/guardrails/test_guardrails_anthropic.py
+++ b/tests/integration/guardrails/test_guardrails_anthropic.py
@@ -238,3 +238,97 @@ async def test_tool_call_guardrail_from_file(
         == "get_capital is called with Germany as argument"
         and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
     )
+
+
+@pytest.mark.skipif(
+    not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
+)
+@pytest.mark.parametrize(
+    "do_stream, push_to_explorer",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_input_from_guardrail_from_file(
+    explorer_api_url, gateway_url, do_stream, push_to_explorer
+):
+    """Test input guardrail enforcement with Anthropic."""
+    if not os.getenv("INVARIANT_API_KEY"):
+        pytest.fail("No INVARIANT_API_KEY set, failing")
+
+    dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
+
+    client = Anthropic(
+        http_client=Client(
+            headers={
+                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
+            },
+        ),
+        base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
+        if push_to_explorer
+        else f"{gateway_url}/api/v1/gateway/anthropic",
+    )
+
+    request = {
+        "model": "claude-3-5-sonnet-20241022",
+        "max_tokens": 100,
+        "messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
+    }
+
+    if not do_stream:
+        with pytest.raises(BadRequestError) as exc_info:
+            _ = client.messages.create(**request, stream=False)
+
+        assert exc_info.value.status_code == 400
+        assert "[Invariant] The request did not pass the guardrails" in str(
+            exc_info.value
+        )
+        assert "Users must not mention the magic phrase 'Fight Club'" in str(
+            exc_info.value
+        )
+
+    else:
+        with pytest.raises(APIStatusError) as exc_info:
+            chat_response = client.messages.create(**request, stream=True)
+            for _ in chat_response:
+                pass
+
+        assert (
+            "[Invariant] The request did not pass the guardrails"
+            in exc_info.value.message
+        )
+        assert "Users must not mention the magic phrase 'Fight Club'" in str(
+            exc_info.value.body
+        )
+
+    if push_to_explorer:
+        time.sleep(2)
+        traces_response = requests.get(
+            f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
+            timeout=5,
+        )
+        traces = traces_response.json()
+        assert len(traces) == 1
+        trace_id = traces[0]["id"]
+
+        trace_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}",
+            timeout=5,
+        )
+        # in case of input guardrailing, the pushed trace will not contain a response
+        trace = trace_response.json()
+        assert len(trace["messages"]) == 1, "Only the user message should be present"
+        assert trace["messages"][0] == {
+            "role": "user",
+            "content": "Tell me more about Fight Club.",
+        }
+
+        annotations_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
+            timeout=5,
+        )
+        annotations = annotations_response.json()
+        assert len(annotations) == 1
+        assert (
+            annotations[0]["content"]
+            == "Users must not mention the magic phrase 'Fight Club'"
+            and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
+        )
diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py
index dfc7845..c463186 100644
--- a/tests/integration/guardrails/test_guardrails_gemini.py
+++ b/tests/integration/guardrails/test_guardrails_gemini.py
@@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file(
     else:
         response = client.models.generate_content_stream(**request)
 
-        for chunk in response:
-            assert "Dublin" not in str(chunk)
+        assert_is_streamed_refusal(
+            response,
+            [
+                "[Invariant] The response did not pass the guardrails",
+                "Dublin detected in the response",
+            ],
+        )
 
     if push_to_explorer:
         # Wait for the trace to be saved
@@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file(
             **request,
         )
 
-        for chunk in response:
-            assert "Madrid" not in str(chunk)
+        assert_is_streamed_refusal(
+            response,
+            [
+                "[Invariant] The response did not pass the guardrails",
+                "get_capital is called with Germany as argument",
+            ],
+        )
 
     if push_to_explorer:
         # Wait for the trace to be saved
@@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file(
         == "get_capital is called with Germany as argument"
         and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
     )
+
+
+@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
+@pytest.mark.parametrize(
+    "do_stream, push_to_explorer",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_input_from_guardrail_from_file(
+    explorer_api_url, gateway_url, do_stream, push_to_explorer
+):
+    """Test input guardrail enforcement with Gemini."""
+    if not os.getenv("INVARIANT_API_KEY"):
+        pytest.fail("No INVARIANT_API_KEY set, failing")
+
+    dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
+
+    client = genai.Client(
+        api_key=os.getenv("GEMINI_API_KEY"),
+        http_options={
+            "headers": {
+                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
+            },
+            "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
+            if push_to_explorer
+            else f"{gateway_url}/api/v1/gateway/gemini",
+        },
+    )
+
+    request = {
+        "model": "gemini-2.0-flash",
+        "contents": "Tell me more about Fight Club.",
+        "config": {
+            "maxOutputTokens": 200,
+        },
+    }
+
+    if not do_stream:
+        with pytest.raises(genai.errors.ClientError) as exc_info:
+            client.models.generate_content(**request)
+
+        assert "[Invariant] The request did not pass the guardrails" in str(
+            exc_info.value
+        )
+        assert "Users must not mention the magic phrase 'Fight Club'" in str(
+            exc_info.value
+        )
+
+    else:
+        response = client.models.generate_content_stream(**request)
+
+        assert_is_streamed_refusal(
+            response,
+            [
+                "[Invariant] The request did not pass the guardrails",
+                "Users must not mention the magic phrase 'Fight Club'",
+            ],
+        )
+
+    if push_to_explorer:
+        time.sleep(2)
+        traces_response = requests.get(
+            f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
+            timeout=5,
+        )
+        traces = traces_response.json()
+        assert len(traces) == 1
+        trace_id = traces[0]["id"]
+
+        trace_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}",
+            timeout=5,
+        )
+        trace = trace_response.json()
+
+        assert len(trace["messages"]) == 1
+        assert trace["messages"][0] == {
+            "role": "user",
+            "content": [{"type": "text", "text": "Tell me more about Fight Club."}],
+        }
+
+        annotations_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
+            timeout=5,
+        )
+        annotations = annotations_response.json()
+
+        assert len(annotations) == 1
+        assert (
+            annotations[0]["content"]
+            == "Users must not mention the magic phrase 'Fight Club'"
+            and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
+        )
+
+
+def is_refusal(chunk):
+    return (
+        len(chunk.candidates) == 1
+        and chunk.candidates[0].content.parts[0].text.startswith("[Invariant]")
+        and chunk.prompt_feedback is not None
+        and "BlockedReason.SAFETY" in str(chunk.prompt_feedback)
+    )
+
+
+def assert_is_streamed_refusal(response, expected_message_components: list[str]):
+    """
+    Validates that the streamed response contains a refusal at the end (or as the only message).
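+
+    The stream is consumed to the end; the final chunk must satisfy
+    `is_refusal`, and each of the `expected_message_components` must appear in
+    the JSON dump of that final chunk.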
+    """
+    num_chunks = 0
+    for c in response:
+        num_chunks += 1
+
+    assert num_chunks >= 1, "Expected at least one chunk"
+    # after the loop, 'c' holds the last chunk; it must be a refusal
+    assert is_refusal(c)
+
+    for emc in expected_message_components:
+        assert (
+            emc in c.model_dump_json()
+        ), f"Expected message component {emc} not found in refusal message: {c.model_dump_json()}"
diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py
index c013107..acc2f67 100644
--- a/tests/integration/guardrails/test_guardrails_open_ai.py
+++ b/tests/integration/guardrails/test_guardrails_open_ai.py
@@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file(
         == "get_capital is called with Germany as argument"
         and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
     )
+
+
+@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
+@pytest.mark.parametrize(
+    "do_stream, push_to_explorer",
+    [(True, True), (True, False), (False, True), (False, False)],
+)
+async def test_input_from_guardrail_from_file(
+    explorer_api_url, gateway_url, do_stream, push_to_explorer
+):
+    """Test input guardrail enforcement with OpenAI."""
+    if not os.getenv("INVARIANT_API_KEY"):
+        pytest.fail("No INVARIANT_API_KEY set, failing")
+
+    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
+
+    client = OpenAI(
+        http_client=Client(
+            headers={
+                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
+            },
+        ),
+        base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
+        if push_to_explorer
+        else f"{gateway_url}/api/v1/gateway/openai",
+    )
+
+    request = {
+        "model": "gpt-4o",
+        "messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
+    }
+
+    if not do_stream:
+        with pytest.raises(BadRequestError) as exc_info:
+            chat_response = client.chat.completions.create(
+                **request,
+                stream=False,
+            )
+
+        assert exc_info.value.status_code == 400
+        assert "[Invariant] The request did not pass the guardrails" in str(
+            exc_info.value
+        )
+        assert "Users must not mention the magic phrase 'Fight Club'" in str(
+            exc_info.value
+        )
+
+    else:
+        with pytest.raises(APIError) as exc_info:
+            chat_response = client.chat.completions.create(
+                **request,
+                stream=True,
+            )
+
+            for _ in chat_response:
+                pass
+        assert (
+            "[Invariant] The request did not pass the guardrails"
+            in exc_info.value.message
+        )
+        assert "Users must not mention the magic phrase 'Fight Club'" in str(
+            exc_info.value.body
+        )
+
+    if push_to_explorer:
+        # Wait for the trace to be saved
+        # This is needed because the trace is saved asynchronously
+        time.sleep(2)
+
+        # Fetch the trace ids for the dataset
+        traces_response = requests.get(
+            f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
+            timeout=5,
+        )
+        traces = traces_response.json()
+        assert len(traces) == 1
+        trace_id = traces[0]["id"]
+
+        # Fetch the trace
+        trace_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}",
+            timeout=5,
+        )
+        trace = trace_response.json()
+
+        # in case of input guardrailing, the pushed trace will not contain a response
+        assert len(trace["messages"]) == 1
+        assert trace["messages"][0] == {
+            "role": "user",
+            "content": "Tell me more about Fight Club.",
+        }
+
+        # Fetch annotations
+        annotations_response = requests.get(
+            f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
+            timeout=5,
+        )
+        annotations = annotations_response.json()
+
+        assert len(annotations) == 1
+        assert (
+            annotations[0]["content"]
+            == "Users must not mention the magic phrase 'Fight Club'"
+            and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
+        )
diff --git a/tests/integration/resources/guardrails/find_capital_guardrails.py b/tests/integration/resources/guardrails/find_capital_guardrails.py
index 8a9caf3..282720f 100644
--- a/tests/integration/resources/guardrails/find_capital_guardrails.py
+++ b/tests/integration/resources/guardrails/find_capital_guardrails.py
@@ -13,4 +13,10 @@ raise "Dublin detected in the response" if:
 raise "get_capital is called with Germany as argument" if:
     (call: ToolCall)
     call is tool:get_capital
-    call.function.arguments["country_name"] == "Germany"
\ No newline at end of file
+    call.function.arguments["country_name"] == "Germany"
+
+# For input guardrailing specifically
+raise "Users must not mention the magic phrase 'Fight Club'" if:
+    (msg: Message)
+    msg.role == "user"
+    "Fight Club" in msg.content
\ No newline at end of file