From 050ec1ba58a74593da18fd4090fc3e1ccc736bfe Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 14:16:05 +0200 Subject: [PATCH] Fetch guardrails from explorer. These have higher precedence than than the guardrails from file. --- gateway/common/config_manager.py | 4 +- gateway/common/guardrails.py | 31 ++++++ gateway/common/request_context.py | 92 ++++++++++++++++ gateway/common/request_context_data.py | 16 --- gateway/integrations/explorer.py | 79 ++++++++++++++ gateway/integrations/guardrails.py | 68 +++++++++--- gateway/routes/anthropic.py | 97 ++++++++++++----- gateway/routes/gemini.py | 86 ++++++++++----- gateway/routes/open_ai.py | 101 ++++++++++++------ run.sh | 5 + ...est_anthropic_header_with_invariant_key.py | 6 +- .../test_anthropic_with_tool_call.py | 2 +- .../test_anthropic_without_tool_call.py | 4 +- .../test_generate_content_with_tool_calls.py | 2 +- ...est_generate_content_without_tool_calls.py | 7 +- .../open_ai/test_chat_with_tool_call.py | 4 +- .../open_ai/test_chat_without_tool_calls.py | 9 +- 17 files changed, 477 insertions(+), 136 deletions(-) create mode 100644 gateway/common/guardrails.py create mode 100644 gateway/common/request_context.py delete mode 100644 gateway/common/request_context_data.py diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 3e62b55..ac51c89 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -11,7 +11,7 @@ class GatewayConfig: """Common configurations for the Gateway Server.""" def __init__(self): - self.guardrails = self._load_guardrails_from_file() + self.guardrails_from_file = self._load_guardrails_from_file() def _load_guardrails_from_file(self) -> str: """ @@ -48,7 +48,7 @@ class GatewayConfig: raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e def __repr__(self) -> str: - return f"GatewayConfig(guardrails={repr(self.guardrails)})" + return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})" class GatewayConfigManager: diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py new file mode 100644 index 0000000..cb7ef1e --- /dev/null +++ b/gateway/common/guardrails.py @@ -0,0 +1,31 @@ +"""Common guardrails data class.""" + +from enum import Enum +from typing import List + +from dataclasses import dataclass + + +class GuardrailAction(str, Enum): + """Enum representing the action to be taken for guardrail rules.""" + + BLOCK = "block" + LOG = "log" + + +@dataclass(frozen=True) +class Guardrail: + """Represents a single guardrail rule.""" + + id: str + name: str + content: str + action: GuardrailAction + + +@dataclass(frozen=True) +class DatasetGuardrails: + """Grouped guardrail rules separated by their action.""" + + blocking_guardrails: List[Guardrail] + logging_guardrails: List[Guardrail] diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py new file mode 100644 index 0000000..f2c8c3e --- /dev/null +++ b/gateway/common/request_context.py @@ -0,0 +1,92 @@ +"""Common Request context data class.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from common.config_manager import GatewayConfig +from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction + + +@dataclass(frozen=True) +class RequestContext: + """Structured context for a request. Must be created via `RequestContext.create()`.""" + + request_json: Dict[str, Any] + dataset_name: Optional[str] = None + invariant_authorization: Optional[str] = None + dataset_guardrails: Optional[DatasetGuardrails] = None + config: Dict[str, Any] = None + + _created_via_factory: bool = field( + default=False, init=True, repr=False, compare=False + ) + + def __post_init__(self): + if not self._created_via_factory: + raise RuntimeError( + "RequestContext must be created using RequestContext.create()" + ) + + @classmethod + def create( + cls, + request_json: Dict[str, Any], + dataset_name: Optional[str] = None, + invariant_authorization: Optional[str] = None, + dataset_guardrails: Optional[DatasetGuardrails] = None, + config: Optional[GatewayConfig] = None, + ) -> "RequestContext": + """Creates a new RequestContext instance, applying default guardrails if needed.""" + + # Convert GatewayConfig to a basic dict, excluding guardrails_from_file + context_config = { + key: value + for key, value in (config.__dict__.items() if config else {}) + if key != "guardrails_from_file" + } + + # If no guardrails are configured for the dataset on Explorer, + # and the config specifies guardrails_from_file, use that. + guardrails = dataset_guardrails + if ( + ( + not dataset_guardrails + or ( + not dataset_guardrails.blocking_guardrails + and not dataset_guardrails.logging_guardrails + ) + ) + and config + and config.guardrails_from_file + ): + # TODO: Support logging guardrails via file. + guardrails = DatasetGuardrails( + blocking_guardrails=[ + Guardrail( + id="default", + name="default", + content=config.guardrails_from_file, + action=GuardrailAction.BLOCK, + ) + ], + logging_guardrails=[], + ) + + return cls( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, + dataset_guardrails=guardrails, + config=context_config, + _created_via_factory=True, + ) + + def __repr__(self) -> str: + return ( + f"RequestContext(" + f"request_json={self.request_json}, " + f"dataset_name={self.dataset_name}, " + f"invariant_authorization={self.invariant_authorization}, " + f"dataset_guardrails={self.dataset_guardrails}, " + f"config={self.config})" + ) diff --git a/gateway/common/request_context_data.py b/gateway/common/request_context_data.py deleted file mode 100644 index 8ae98ba..0000000 --- a/gateway/common/request_context_data.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Common Request context data class.""" - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from common.config_manager import GatewayConfig - - -@dataclass(frozen=True) -class RequestContextData: - """Request context data class.""" - - request_json: Dict[str, Any] - dataset_name: Optional[str] = None - invariant_authorization: Optional[str] = None - config: Optional[GatewayConfig] = None diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index fd0760b..29e0710 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -3,10 +3,13 @@ import os from typing import Any, Dict, List +from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse from invariant_sdk.types.annotations import AnnotationCreate +import httpx + DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -91,3 +94,79 @@ async def push_trace( except Exception as e: print(f"Failed to push trace: {e}") return {"error": str(e)} + + +async def fetch_guardrails_from_explorer( + dataset_name: str, invariant_authorization: str +) -> DatasetGuardrails: + """Get the guardrails for the dataset. + + Returns: + DatasetGuardrails: The guardrails for the dataset grouped by their action. + """ + + # TODO: Implement a single API in explorer backend which can return + # dataset details without requiring a username. + + client = httpx.AsyncClient( + base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + headers={ + "Invariant-Authorization": invariant_authorization, + }, + ) + + # Get the user details. + user_info_response = await client.get("/api/v1/user/info") + if user_info_response.status_code != 200: + raise ValueError( + f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}" + ) + user_details = user_info_response.json() + username = user_details["username"] + + # Get the dataset policies. + policies_response = await client.get( + f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy" + ) + if policies_response.status_code != 200: + if policies_response.status_code == 404: + # If the dataset does not exist, return empty guardrails. + return DatasetGuardrails( + blocking_guardrails=[], + logging_guardrails=[], + ) + raise ValueError( + f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}" + ) + policies_details = policies_response.json() + guardrails = policies_details.get("policies", []) + + blocking_guardrails = [] + logging_guardrails = [] + for g in guardrails: + action = g["action"] + + if not g["enabled"]: + # Skip guardrails that are not enabled. + continue + + if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG): + print("[Warning] Skipping unknown guardrail action: ", action) + continue + + guardrail = Guardrail( + id=g["id"], + name=g["name"], + content=g["content"], + action=GuardrailAction(action), + ) + + if action == GuardrailAction.BLOCK: + blocking_guardrails.append(guardrail) + else: + logging_guardrails.append(guardrail) + + return DatasetGuardrails( + blocking_guardrails=blocking_guardrails, + logging_guardrails=logging_guardrails, + ) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index b0f3601..412418e 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -1,13 +1,15 @@ """Utility functions for Guardrails execution.""" import asyncio +import json import os import time from typing import Any, Dict, List from functools import wraps import httpx -from common.request_context_data import RequestContextData +from common.guardrails import Guardrail +from common.request_context import RequestContext DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -81,21 +83,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None: result.raise_for_status() -async def preload_guardrails(context: "RequestContextData") -> None: +async def preload_guardrails(context: "RequestContext") -> None: """ Preloads the guardrails for faster checking later. Args: - context: RequestContextData object. + context: RequestContext object. """ - if not context.config or not context.config.guardrails: + if not context.dataset_guardrails: return try: - task = asyncio.create_task( - _preload(context.config.guardrails, context.invariant_authorization) - ) - asyncio.shield(task) + # Move these calls to a batch preload/validate API. + for blocking_guardrail in context.dataset_guardrails.blocking_guardrails: + task = asyncio.create_task( + _preload(blocking_guardrail.content, context.invariant_authorization) + ) + asyncio.shield(task) + for logging_guadrail in context.dataset_guardrails.logging_guardrails: + task = asyncio.create_task( + _preload(logging_guadrail.content, context.invariant_authorization) + ) + asyncio.shield(task) except Exception as e: print(f"Error scheduling preload_guardrails task: {e}") @@ -322,14 +331,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse): async def check_guardrails( - messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str + messages: List[Dict[str, Any]], + guardrails: List[Guardrail], + invariant_authorization: str, ) -> Dict[str, Any]: """ Checks guardrails on the list of messages. + This calls the batch check API of the Guardrails service. Args: messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (str): The guardrails to check against. + guardrails (List[Guardrail]): The guardrails to check against. invariant_authorization (str): Value of the invariant-authorization header. @@ -339,9 +351,34 @@ async def check_guardrails( async with httpx.AsyncClient() as client: url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") try: + print( + "Hello there this is the request to guardrails: ", + json.dumps( + { + "messages": messages, + "policies": [g.content for g in guardrails], + }, + indent=2, + ), + flush=True, + ) + print( + "Hello there this is the request to guardrails: ", + json.dumps( + { + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + indent=2, + ), + flush=True, + ) result = await client.post( - f"{url}/api/v1/policy/check", - json={"messages": messages, "policy": guardrails}, + f"{url}/api/v1/policy/check/batch", + json={ + "messages": messages, + "policies": [g.content for g in guardrails], + }, headers={ "Authorization": invariant_authorization, "Accept": "application/json", @@ -352,7 +389,12 @@ async def check_guardrails( f"Guardrails check failed: {result.status_code} - {result.text}" ) print(f"Guardrail check response: {result.json()}") - return result.json() + + guardrails_result = result.json() + aggregated_errors = {"errors": []} + for res in guardrails_result.get("result", []): + aggregated_errors["errors"].extend(res.get("errors", [])) + return aggregated_errors except Exception as e: print(f"Failed to verify guardrails: {e}") return {"error": str(e)} diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 9e5187b..9905258 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -14,11 +14,16 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData +from common.guardrails import GuardrailAction +from common.request_context import RequestContext from converters.anthropic_to_invariant import ( convert_anthropic_to_invariant_message_format, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -83,10 +88,17 @@ async def anthropic_v1_messages_gateway( data=request_body, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset from explorer. + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -97,7 +109,7 @@ async def anthropic_v1_messages_gateway( def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = {k: v for k, v in context.request_json.items() if k != "messages"} @@ -108,7 +120,7 @@ def create_metadata( def combine_request_and_response_messages( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, json_response: dict[str, Any] ): """Combine the request and response messages""" messages = [] @@ -123,23 +135,32 @@ def combine_request_and_response_messages( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, action: GuardrailAction, json_response: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + messages = combine_request_and_response_messages(context, json_response) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -163,14 +184,16 @@ async def push_to_explorer( class InstrumentedAnthropicResponse(InstrumentedResponse): + """Instrumented response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request @@ -184,9 +207,9 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, json_response={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -264,10 +287,17 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): assert self.json_response is not None, "json_response is None" assert self.response_string is not None, "response_string is None" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.json_response, + ) + print( + "Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ", + guardrails_execution_result, + flush=True, ) if guardrails_execution_result.get("errors", []): guardrail_response_string = json.dumps( @@ -306,7 +336,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> Response: @@ -320,17 +350,19 @@ async def handle_non_streaming_response( return await response.instrumented_request() -class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): +class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request @@ -342,9 +374,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -392,6 +426,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Process the chunk and update the merged_response""" decoded_chunk = chunk.decode().strip() if not decoded_chunk: return @@ -400,14 +435,17 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): process_chunk(decoded_chunk, self.merged_response) # on last stream chunk, run output guardrails - if ( - "event: message_stop" in decoded_chunk - and self.context.config - and self.context.config.guardrails - ): + if "event: message_stop" in decoded_chunk and self.context.dataset_guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, + ) + print( + "Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ", + self.guardrails_execution_result, + flush=True, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -420,7 +458,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): } ) - # yield an extra error chunk (without preventing the original chunk to go through after, + # yield an extra error chunk (without preventing the original chunk + # to go through after, # so client gets the proper message_stop event still) return ExtraItem( value=f"event: error\ndata: {error_chunk}\n\n".encode() @@ -440,12 +479,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def handle_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> StreamingResponse: """Handles streaming Anthropic responses""" - response = InstrumentedAnthropicStreamingResposne( + response = InstrumentedAnthropicStreamingResponse( context=context, client=client, anthropic_request=anthropic_request, diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index ae86bf4..1e21b90 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -14,9 +14,14 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData +from common.guardrails import GuardrailAction +from common.request_context import RequestContext from converters.gemini_to_invariant import convert_request, convert_response -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -76,10 +81,17 @@ async def gemini_generate_content_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -98,16 +110,18 @@ async def gemini_generate_content_gateway( class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -124,6 +138,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): location: Literal["request", "response"], guardrails_execution_result: dict[str, Any], ) -> dict: + """Create a refusal response for the given request or response""" return { "candidates": [ { @@ -157,10 +172,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): } async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -184,6 +202,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): ) async def event_generator(self): + """Event generator for streaming responses""" response = await self.client.send(self.gemini_request, stream=True) if response.status_code != 200: @@ -199,6 +218,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Processes each chunk of the streaming response""" chunk_text = chunk.decode().strip() if not chunk_text: return @@ -210,12 +230,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): if ( self.merged_response.get("candidates", []) and self.merged_response.get("candidates")[0].get("finishReason", "") - and self.context.config - and self.context.config.guardrails + and self.context.dataset_guardrails ): # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -254,7 +275,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): async def stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: @@ -332,7 +353,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) -> def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -352,23 +373,32 @@ def create_metadata( async def get_guardrails_check_result( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_requests + converted_responses, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, response_json: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -391,16 +421,18 @@ async def push_to_explorer( class InstrumentedGeminiResponse(InstrumentedResponse): + """Instrumented response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -412,10 +444,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse): self.guardrails_execution_result: Optional[dict[str, Any]] = None async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -463,6 +498,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def request(self): + """Makes the request to the Gemini API and return the response""" self.response = await self.client.send(self.gemini_request) response_string = self.response.text @@ -492,10 +528,12 @@ class InstrumentedGeminiResponse(InstrumentedResponse): response_string = json.dumps(self.response_json) response_code = self.response.status_code - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.response_json + self.context, + action=GuardrailAction.BLOCK, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -539,7 +577,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index e3cec41..ada5309 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -14,8 +14,13 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from common.guardrails import GuardrailAction +from common.request_context import RequestContext +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -72,10 +77,17 @@ async def openai_chat_completions_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -92,19 +104,20 @@ async def openai_chat_completions_gateway( class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): """ - Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + Does a streaming OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request @@ -131,10 +144,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.tool_call_mapping_by_index = {} async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -164,6 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def on_chunk(self, chunk): + """Processes each chunk of the stream and checks guardrails at the end of the stream""" # process and check each chunk chunk_text = chunk.decode().strip() if not chunk_text: @@ -179,14 +198,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) - if ( - "data: [DONE]" in chunk_text - and self.context.config - and self.context.config.guardrails - ): + if "data: [DONE]" in chunk_text and self.context.dataset_guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -214,10 +231,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def event_generator(self): - """ - Actual OpenAI stream response. - """ - + """Actual OpenAI stream response.""" response = await self.client.send(self.open_ai_request, stream=True) if response.status_code != 200: error_content = await response.aread() @@ -234,7 +248,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): async def handle_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: @@ -389,7 +403,7 @@ def update_existing_choice_with_delta( def create_metadata( - context: RequestContextData, merged_response: dict[str, Any] + context: RequestContext, merged_response: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -409,7 +423,7 @@ def create_metadata( async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -437,18 +451,28 @@ async def push_to_explorer( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] | None = None + context: RequestContext, + action: GuardrailAction, + json_response: dict[str, Any] | None = None, ) -> dict[str, Any]: """Get the guardrails check result""" - messages = list(context.request_json.get("messages", [])) + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + messages = list(context.request_json.get("messages", [])) if json_response is not None: messages += [choice["message"] for choice in json_response.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result @@ -456,19 +480,20 @@ async def get_guardrails_check_result( class InstrumentedOpenAIResponse(InstrumentedResponse): """ - Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + Does an OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request @@ -480,11 +505,14 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.guardrails_execution_result: Optional[dict] = None async def on_start(self): - """Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)""" - if self.context.config and self.context.config.guardrails: + """ + Checks guardrails in a pipelined fashion, before processing + the first chunk (for input guardrailing) + """ + if self.context.dataset_guardrails: # block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context + self.context, action=GuardrailAction.BLOCK ) if self.guardrails_execution_result.get("errors", []): # Push annotated trace to the explorer - don't block on its response @@ -542,7 +570,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def on_end(self): """Postprocesses the OpenAI response and potentially replace it with a guardrails error.""" - # these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed) + # these two request outputs are guaranteed to be available by the time we reach + # this point (after self.request() was executed) # nevertheless, we check for them to avoid any potential issues assert ( self.response is not None @@ -555,10 +584,12 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): response_code = self.response.status_code # if we have guardrails, check the response - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # run guardrails again, this time on request + response self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.json_response, ) if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -601,7 +632,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def handle_non_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: diff --git a/run.sh b/run.sh index f666f96..72ad68e 100755 --- a/run.sh +++ b/run.sh @@ -93,6 +93,11 @@ integration_tests() { fi echo "File successfully downloaded: $FILE" + if [[ -z "$INVARIANT_API_KEY" ]]; then + echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests." + exit 1 + fi + TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py" if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then diff --git a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py index a2e3db0..1597846 100644 --- a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py @@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header( """Test the Anthropic gateway with Invariant key in the Anthropic key""" anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - { - "ANTHROPIC_API_KEY": anthropic_api_key - + ";invariant-auth=" - }, + {"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix}, ): client = anthropic.Anthropic( http_client=Client(), diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 89591ea..560ada0 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -26,7 +26,7 @@ class WeatherAgent: def __init__(self, gateway_url, push_to_explorer): self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") self.client = anthropic.Anthropic( http_client=Client( headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, diff --git a/tests/integration/anthropic/test_anthropic_without_tool_call.py b/tests/integration/anthropic/test_anthropic_without_tool_call.py index a397c25..c47eaba 100644 --- a/tests/integration/anthropic/test_anthropic_without_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_without_tool_call.py @@ -26,7 +26,7 @@ async def test_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") client = anthropic.Anthropic( http_client=Client( @@ -91,7 +91,7 @@ async def test_streaming_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") client = anthropic.Anthropic( http_client=Client( diff --git a/tests/integration/gemini/test_generate_content_with_tool_calls.py b/tests/integration/gemini/test_generate_content_with_tool_calls.py index bcf838e..7d2486e 100644 --- a/tests/integration/gemini/test_generate_content_with_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_with_tool_calls.py @@ -151,7 +151,7 @@ async def test_generate_content_with_tool_call( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index d16797a..cf3f42f 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -36,7 +36,7 @@ async def test_generate_content( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) @@ -123,7 +123,7 @@ async def test_generate_content_with_image( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) @@ -181,9 +181,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( """Test the generate content gateway calls with the Invariant API Key in the Gemini Key header.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" gemini_api_key = os.getenv("GEMINI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"GEMINI_API_KEY": gemini_api_key + ";invariant-auth="}, + {"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix}, ): client = genai.Client( api_key=os.getenv("GEMINI_API_KEY"), diff --git a/tests/integration/open_ai/test_chat_with_tool_call.py b/tests/integration/open_ai/test_chat_with_tool_call.py index 69ad6af..5e96db9 100644 --- a/tests/integration/open_ai/test_chat_with_tool_call.py +++ b/tests/integration/open_ai/test_chat_with_tool_call.py @@ -32,7 +32,7 @@ async def test_chat_completion_with_tool_call_without_streaming( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -150,7 +150,7 @@ async def test_chat_completion_with_tool_call_with_streaming( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" diff --git a/tests/integration/open_ai/test_chat_without_tool_calls.py b/tests/integration/open_ai/test_chat_without_tool_calls.py index e9134b3..a8745fb 100644 --- a/tests/integration/open_ai/test_chat_without_tool_calls.py +++ b/tests/integration/open_ai/test_chat_without_tool_calls.py @@ -34,7 +34,7 @@ async def test_chat_completion( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -107,7 +107,7 @@ async def test_chat_completion_with_image( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -189,9 +189,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header( """Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" openai_api_key = os.getenv("OPENAI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"OPENAI_API_KEY": openai_api_key + ";invariant-auth="}, + {"OPENAI_API_KEY": openai_api_key + invariant_key_suffix}, ): client = OpenAI( http_client=Client(), @@ -252,7 +253,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream): client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",