From f45a973f512c403b2a19c8768449554f99628dd8 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 10:04:00 +0200 Subject: [PATCH 1/7] Small formatting changes. --- .env | 2 +- gateway/.env | 6 +++--- gateway/routes/anthropic.py | 10 +++++----- gateway/routes/gemini.py | 8 ++++---- gateway/routes/open_ai.py | 7 ++++--- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/.env b/.env index af1265e..6fa5ec4 100644 --- a/.env +++ b/.env @@ -3,4 +3,4 @@ # If you want to push to a local instance of explorer, then specify the app-api docker container name like: # http://:8000 to push to the local explorer instance. INVARIANT_API_URL=https://explorer.invariantlabs.ai -GUADRAILS_API_URL=https://explorer.invariantlabs.ai +GUADRAILS_API_URL=https://explorer.invariantlabs.ai \ No newline at end of file diff --git a/gateway/.env b/gateway/.env index 78d3391..c84df8b 100644 --- a/gateway/.env +++ b/gateway/.env @@ -1,4 +1,4 @@ POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres - POSTGRES_DB=invariantmonitor - POSTGRES_HOST=database \ No newline at end of file +POSTGRES_PASSWORD=postgres +POSTGRES_DB=invariantmonitor +POSTGRES_HOST=database \ No newline at end of file diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 2f3e243..9e5187b 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -5,20 +5,20 @@ import json from typing import Any, Optional import httpx -from regex import R -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import GatewayConfig, GatewayConfigManager from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from common.request_context_data import RequestContextData from converters.anthropic_to_invariant import ( convert_anthropic_to_invariant_message_format, ) -from common.authorization import extract_authorization_from_headers -from common.request_context_data import RequestContextData +from integrations.explorer import create_annotations_from_guardrails_errors, push_trace from integrations.guardrails import ( ExtraItem, InstrumentedResponse, diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 59d0874..ae86bf4 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -5,16 +5,18 @@ import json from typing import Any, Literal, Optional import httpx -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response from fastapi.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import GatewayConfig, GatewayConfigManager from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData from converters.gemini_to_invariant import convert_request, convert_response +from integrations.explorer import create_annotations_from_guardrails_errors, push_trace from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -23,8 +25,6 @@ from integrations.guardrails import ( preload_guardrails, check_guardrails, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from integrations.guardrails import check_guardrails, preload_guardrails gateway = APIRouter() diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 6ef3808..e3cec41 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -5,13 +5,16 @@ import json from typing import Any, Optional import httpx -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import GatewayConfig, GatewayConfigManager from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) +from common.request_context_data import RequestContextData from integrations.explorer import create_annotations_from_guardrails_errors, push_trace from integrations.guardrails import ( ExtraItem, @@ -20,8 +23,6 @@ from integrations.guardrails import ( check_guardrails, preload_guardrails, ) -from common.authorization import extract_authorization_from_headers -from common.request_context_data import RequestContextData gateway = APIRouter() From 050ec1ba58a74593da18fd4090fc3e1ccc736bfe Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 14:16:05 +0200 Subject: [PATCH 2/7] Fetch guardrails from explorer. These have higher precedence than than the guardrails from file. --- gateway/common/config_manager.py | 4 +- gateway/common/guardrails.py | 31 ++++++ gateway/common/request_context.py | 92 ++++++++++++++++ gateway/common/request_context_data.py | 16 --- gateway/integrations/explorer.py | 79 ++++++++++++++ gateway/integrations/guardrails.py | 68 +++++++++--- gateway/routes/anthropic.py | 97 ++++++++++++----- gateway/routes/gemini.py | 86 ++++++++++----- gateway/routes/open_ai.py | 101 ++++++++++++------ run.sh | 5 + ...est_anthropic_header_with_invariant_key.py | 6 +- .../test_anthropic_with_tool_call.py | 2 +- .../test_anthropic_without_tool_call.py | 4 +- .../test_generate_content_with_tool_calls.py | 2 +- ...est_generate_content_without_tool_calls.py | 7 +- .../open_ai/test_chat_with_tool_call.py | 4 +- .../open_ai/test_chat_without_tool_calls.py | 9 +- 17 files changed, 477 insertions(+), 136 deletions(-) create mode 100644 gateway/common/guardrails.py create mode 100644 gateway/common/request_context.py delete mode 100644 gateway/common/request_context_data.py diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 3e62b55..ac51c89 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -11,7 +11,7 @@ class GatewayConfig: """Common configurations for the Gateway Server.""" def __init__(self): - self.guardrails = self._load_guardrails_from_file() + self.guardrails_from_file = self._load_guardrails_from_file() def _load_guardrails_from_file(self) -> str: """ @@ -48,7 +48,7 @@ class GatewayConfig: raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e def __repr__(self) -> str: - return f"GatewayConfig(guardrails={repr(self.guardrails)})" + return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})" class GatewayConfigManager: diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py new file mode 100644 index 0000000..cb7ef1e --- /dev/null +++ b/gateway/common/guardrails.py @@ -0,0 +1,31 @@ +"""Common guardrails data class.""" + +from enum import Enum +from typing import List + +from dataclasses import dataclass + + +class GuardrailAction(str, Enum): + """Enum representing the action to be taken for guardrail rules.""" + + BLOCK = "block" + LOG = "log" + + +@dataclass(frozen=True) +class Guardrail: + """Represents a single guardrail rule.""" + + id: str + name: str + content: str + action: GuardrailAction + + +@dataclass(frozen=True) +class DatasetGuardrails: + """Grouped guardrail rules separated by their action.""" + + blocking_guardrails: List[Guardrail] + logging_guardrails: List[Guardrail] diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py new file mode 100644 index 0000000..f2c8c3e --- /dev/null +++ b/gateway/common/request_context.py @@ -0,0 +1,92 @@ +"""Common Request context data class.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from common.config_manager import GatewayConfig +from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction + + +@dataclass(frozen=True) +class RequestContext: + """Structured context for a request. Must be created via `RequestContext.create()`.""" + + request_json: Dict[str, Any] + dataset_name: Optional[str] = None + invariant_authorization: Optional[str] = None + dataset_guardrails: Optional[DatasetGuardrails] = None + config: Dict[str, Any] = None + + _created_via_factory: bool = field( + default=False, init=True, repr=False, compare=False + ) + + def __post_init__(self): + if not self._created_via_factory: + raise RuntimeError( + "RequestContext must be created using RequestContext.create()" + ) + + @classmethod + def create( + cls, + request_json: Dict[str, Any], + dataset_name: Optional[str] = None, + invariant_authorization: Optional[str] = None, + dataset_guardrails: Optional[DatasetGuardrails] = None, + config: Optional[GatewayConfig] = None, + ) -> "RequestContext": + """Creates a new RequestContext instance, applying default guardrails if needed.""" + + # Convert GatewayConfig to a basic dict, excluding guardrails_from_file + context_config = { + key: value + for key, value in (config.__dict__.items() if config else {}) + if key != "guardrails_from_file" + } + + # If no guardrails are configured for the dataset on Explorer, + # and the config specifies guardrails_from_file, use that. + guardrails = dataset_guardrails + if ( + ( + not dataset_guardrails + or ( + not dataset_guardrails.blocking_guardrails + and not dataset_guardrails.logging_guardrails + ) + ) + and config + and config.guardrails_from_file + ): + # TODO: Support logging guardrails via file. + guardrails = DatasetGuardrails( + blocking_guardrails=[ + Guardrail( + id="default", + name="default", + content=config.guardrails_from_file, + action=GuardrailAction.BLOCK, + ) + ], + logging_guardrails=[], + ) + + return cls( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, + dataset_guardrails=guardrails, + config=context_config, + _created_via_factory=True, + ) + + def __repr__(self) -> str: + return ( + f"RequestContext(" + f"request_json={self.request_json}, " + f"dataset_name={self.dataset_name}, " + f"invariant_authorization={self.invariant_authorization}, " + f"dataset_guardrails={self.dataset_guardrails}, " + f"config={self.config})" + ) diff --git a/gateway/common/request_context_data.py b/gateway/common/request_context_data.py deleted file mode 100644 index 8ae98ba..0000000 --- a/gateway/common/request_context_data.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Common Request context data class.""" - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from common.config_manager import GatewayConfig - - -@dataclass(frozen=True) -class RequestContextData: - """Request context data class.""" - - request_json: Dict[str, Any] - dataset_name: Optional[str] = None - invariant_authorization: Optional[str] = None - config: Optional[GatewayConfig] = None diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index fd0760b..29e0710 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -3,10 +3,13 @@ import os from typing import Any, Dict, List +from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse from invariant_sdk.types.annotations import AnnotationCreate +import httpx + DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -91,3 +94,79 @@ async def push_trace( except Exception as e: print(f"Failed to push trace: {e}") return {"error": str(e)} + + +async def fetch_guardrails_from_explorer( + dataset_name: str, invariant_authorization: str +) -> DatasetGuardrails: + """Get the guardrails for the dataset. + + Returns: + DatasetGuardrails: The guardrails for the dataset grouped by their action. + """ + + # TODO: Implement a single API in explorer backend which can return + # dataset details without requiring a username. + + client = httpx.AsyncClient( + base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + headers={ + "Invariant-Authorization": invariant_authorization, + }, + ) + + # Get the user details. + user_info_response = await client.get("/api/v1/user/info") + if user_info_response.status_code != 200: + raise ValueError( + f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}" + ) + user_details = user_info_response.json() + username = user_details["username"] + + # Get the dataset policies. + policies_response = await client.get( + f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy" + ) + if policies_response.status_code != 200: + if policies_response.status_code == 404: + # If the dataset does not exist, return empty guardrails. + return DatasetGuardrails( + blocking_guardrails=[], + logging_guardrails=[], + ) + raise ValueError( + f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}" + ) + policies_details = policies_response.json() + guardrails = policies_details.get("policies", []) + + blocking_guardrails = [] + logging_guardrails = [] + for g in guardrails: + action = g["action"] + + if not g["enabled"]: + # Skip guardrails that are not enabled. + continue + + if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG): + print("[Warning] Skipping unknown guardrail action: ", action) + continue + + guardrail = Guardrail( + id=g["id"], + name=g["name"], + content=g["content"], + action=GuardrailAction(action), + ) + + if action == GuardrailAction.BLOCK: + blocking_guardrails.append(guardrail) + else: + logging_guardrails.append(guardrail) + + return DatasetGuardrails( + blocking_guardrails=blocking_guardrails, + logging_guardrails=logging_guardrails, + ) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index b0f3601..412418e 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -1,13 +1,15 @@ """Utility functions for Guardrails execution.""" import asyncio +import json import os import time from typing import Any, Dict, List from functools import wraps import httpx -from common.request_context_data import RequestContextData +from common.guardrails import Guardrail +from common.request_context import RequestContext DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -81,21 +83,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None: result.raise_for_status() -async def preload_guardrails(context: "RequestContextData") -> None: +async def preload_guardrails(context: "RequestContext") -> None: """ Preloads the guardrails for faster checking later. Args: - context: RequestContextData object. + context: RequestContext object. """ - if not context.config or not context.config.guardrails: + if not context.dataset_guardrails: return try: - task = asyncio.create_task( - _preload(context.config.guardrails, context.invariant_authorization) - ) - asyncio.shield(task) + # Move these calls to a batch preload/validate API. + for blocking_guardrail in context.dataset_guardrails.blocking_guardrails: + task = asyncio.create_task( + _preload(blocking_guardrail.content, context.invariant_authorization) + ) + asyncio.shield(task) + for logging_guadrail in context.dataset_guardrails.logging_guardrails: + task = asyncio.create_task( + _preload(logging_guadrail.content, context.invariant_authorization) + ) + asyncio.shield(task) except Exception as e: print(f"Error scheduling preload_guardrails task: {e}") @@ -322,14 +331,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse): async def check_guardrails( - messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str + messages: List[Dict[str, Any]], + guardrails: List[Guardrail], + invariant_authorization: str, ) -> Dict[str, Any]: """ Checks guardrails on the list of messages. + This calls the batch check API of the Guardrails service. Args: messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (str): The guardrails to check against. + guardrails (List[Guardrail]): The guardrails to check against. invariant_authorization (str): Value of the invariant-authorization header. @@ -339,9 +351,34 @@ async def check_guardrails( async with httpx.AsyncClient() as client: url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") try: + print( + "Hello there this is the request to guardrails: ", + json.dumps( + { + "messages": messages, + "policies": [g.content for g in guardrails], + }, + indent=2, + ), + flush=True, + ) + print( + "Hello there this is the request to guardrails: ", + json.dumps( + { + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + indent=2, + ), + flush=True, + ) result = await client.post( - f"{url}/api/v1/policy/check", - json={"messages": messages, "policy": guardrails}, + f"{url}/api/v1/policy/check/batch", + json={ + "messages": messages, + "policies": [g.content for g in guardrails], + }, headers={ "Authorization": invariant_authorization, "Accept": "application/json", @@ -352,7 +389,12 @@ async def check_guardrails( f"Guardrails check failed: {result.status_code} - {result.text}" ) print(f"Guardrail check response: {result.json()}") - return result.json() + + guardrails_result = result.json() + aggregated_errors = {"errors": []} + for res in guardrails_result.get("result", []): + aggregated_errors["errors"].extend(res.get("errors", [])) + return aggregated_errors except Exception as e: print(f"Failed to verify guardrails: {e}") return {"error": str(e)} diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 9e5187b..9905258 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -14,11 +14,16 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData +from common.guardrails import GuardrailAction +from common.request_context import RequestContext from converters.anthropic_to_invariant import ( convert_anthropic_to_invariant_message_format, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -83,10 +88,17 @@ async def anthropic_v1_messages_gateway( data=request_body, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset from explorer. + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -97,7 +109,7 @@ async def anthropic_v1_messages_gateway( def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = {k: v for k, v in context.request_json.items() if k != "messages"} @@ -108,7 +120,7 @@ def create_metadata( def combine_request_and_response_messages( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, json_response: dict[str, Any] ): """Combine the request and response messages""" messages = [] @@ -123,23 +135,32 @@ def combine_request_and_response_messages( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, action: GuardrailAction, json_response: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + messages = combine_request_and_response_messages(context, json_response) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -163,14 +184,16 @@ async def push_to_explorer( class InstrumentedAnthropicResponse(InstrumentedResponse): + """Instrumented response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request @@ -184,9 +207,9 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, json_response={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -264,10 +287,17 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): assert self.json_response is not None, "json_response is None" assert self.response_string is not None, "response_string is None" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.json_response, + ) + print( + "Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ", + guardrails_execution_result, + flush=True, ) if guardrails_execution_result.get("errors", []): guardrail_response_string = json.dumps( @@ -306,7 +336,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> Response: @@ -320,17 +350,19 @@ async def handle_non_streaming_response( return await response.instrumented_request() -class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): +class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request @@ -342,9 +374,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -392,6 +426,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Process the chunk and update the merged_response""" decoded_chunk = chunk.decode().strip() if not decoded_chunk: return @@ -400,14 +435,17 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): process_chunk(decoded_chunk, self.merged_response) # on last stream chunk, run output guardrails - if ( - "event: message_stop" in decoded_chunk - and self.context.config - and self.context.config.guardrails - ): + if "event: message_stop" in decoded_chunk and self.context.dataset_guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, + ) + print( + "Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ", + self.guardrails_execution_result, + flush=True, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -420,7 +458,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): } ) - # yield an extra error chunk (without preventing the original chunk to go through after, + # yield an extra error chunk (without preventing the original chunk + # to go through after, # so client gets the proper message_stop event still) return ExtraItem( value=f"event: error\ndata: {error_chunk}\n\n".encode() @@ -440,12 +479,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def handle_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> StreamingResponse: """Handles streaming Anthropic responses""" - response = InstrumentedAnthropicStreamingResposne( + response = InstrumentedAnthropicStreamingResponse( context=context, client=client, anthropic_request=anthropic_request, diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index ae86bf4..1e21b90 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -14,9 +14,14 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData +from common.guardrails import GuardrailAction +from common.request_context import RequestContext from converters.gemini_to_invariant import convert_request, convert_response -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -76,10 +81,17 @@ async def gemini_generate_content_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -98,16 +110,18 @@ async def gemini_generate_content_gateway( class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -124,6 +138,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): location: Literal["request", "response"], guardrails_execution_result: dict[str, Any], ) -> dict: + """Create a refusal response for the given request or response""" return { "candidates": [ { @@ -157,10 +172,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): } async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -184,6 +202,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): ) async def event_generator(self): + """Event generator for streaming responses""" response = await self.client.send(self.gemini_request, stream=True) if response.status_code != 200: @@ -199,6 +218,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Processes each chunk of the streaming response""" chunk_text = chunk.decode().strip() if not chunk_text: return @@ -210,12 +230,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): if ( self.merged_response.get("candidates", []) and self.merged_response.get("candidates")[0].get("finishReason", "") - and self.context.config - and self.context.config.guardrails + and self.context.dataset_guardrails ): # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -254,7 +275,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): async def stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: @@ -332,7 +353,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) -> def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -352,23 +373,32 @@ def create_metadata( async def get_guardrails_check_result( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_requests + converted_responses, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, response_json: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -391,16 +421,18 @@ async def push_to_explorer( class InstrumentedGeminiResponse(InstrumentedResponse): + """Instrumented response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -412,10 +444,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse): self.guardrails_execution_result: Optional[dict[str, Any]] = None async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -463,6 +498,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def request(self): + """Makes the request to the Gemini API and return the response""" self.response = await self.client.send(self.gemini_request) response_string = self.response.text @@ -492,10 +528,12 @@ class InstrumentedGeminiResponse(InstrumentedResponse): response_string = json.dumps(self.response_json) response_code = self.response.status_code - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.response_json + self.context, + action=GuardrailAction.BLOCK, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -539,7 +577,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index e3cec41..ada5309 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -14,8 +14,13 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.request_context_data import RequestContextData -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from common.guardrails import GuardrailAction +from common.request_context import RequestContext +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -72,10 +77,17 @@ async def openai_chat_completions_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + dataset_guardrails=dataset_guardrails, config=config, ) asyncio.create_task(preload_guardrails(context)) @@ -92,19 +104,20 @@ async def openai_chat_completions_gateway( class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): """ - Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + Does a streaming OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request @@ -131,10 +144,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.tool_call_mapping_by_index = {} async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -164,6 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def on_chunk(self, chunk): + """Processes each chunk of the stream and checks guardrails at the end of the stream""" # process and check each chunk chunk_text = chunk.decode().strip() if not chunk_text: @@ -179,14 +198,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) - if ( - "data: [DONE]" in chunk_text - and self.context.config - and self.context.config.guardrails - ): + if "data: [DONE]" in chunk_text and self.context.dataset_guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -214,10 +231,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def event_generator(self): - """ - Actual OpenAI stream response. - """ - + """Actual OpenAI stream response.""" response = await self.client.send(self.open_ai_request, stream=True) if response.status_code != 200: error_content = await response.aread() @@ -234,7 +248,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): async def handle_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: @@ -389,7 +403,7 @@ def update_existing_choice_with_delta( def create_metadata( - context: RequestContextData, merged_response: dict[str, Any] + context: RequestContext, merged_response: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -409,7 +423,7 @@ def create_metadata( async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -437,18 +451,28 @@ async def push_to_explorer( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] | None = None + context: RequestContext, + action: GuardrailAction, + json_response: dict[str, Any] | None = None, ) -> dict[str, Any]: """Get the guardrails check result""" - messages = list(context.request_json.get("messages", [])) + # Determine which guardrails to apply based on the action + guardrails = ( + context.dataset_guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.dataset_guardrails.blocking_guardrails + ) + if not guardrails: + return {} + messages = list(context.request_json.get("messages", [])) if json_response is not None: messages += [choice["message"] for choice in json_response.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result @@ -456,19 +480,20 @@ async def get_guardrails_check_result( class InstrumentedOpenAIResponse(InstrumentedResponse): """ - Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + Does an OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request @@ -480,11 +505,14 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.guardrails_execution_result: Optional[dict] = None async def on_start(self): - """Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)""" - if self.context.config and self.context.config.guardrails: + """ + Checks guardrails in a pipelined fashion, before processing + the first chunk (for input guardrailing) + """ + if self.context.dataset_guardrails: # block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context + self.context, action=GuardrailAction.BLOCK ) if self.guardrails_execution_result.get("errors", []): # Push annotated trace to the explorer - don't block on its response @@ -542,7 +570,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def on_end(self): """Postprocesses the OpenAI response and potentially replace it with a guardrails error.""" - # these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed) + # these two request outputs are guaranteed to be available by the time we reach + # this point (after self.request() was executed) # nevertheless, we check for them to avoid any potential issues assert ( self.response is not None @@ -555,10 +584,12 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): response_code = self.response.status_code # if we have guardrails, check the response - if self.context.config and self.context.config.guardrails: + if self.context.dataset_guardrails: # run guardrails again, this time on request + response self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response + self.context, + action=GuardrailAction.BLOCK, + json_response=self.json_response, ) if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -601,7 +632,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def handle_non_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: diff --git a/run.sh b/run.sh index f666f96..72ad68e 100755 --- a/run.sh +++ b/run.sh @@ -93,6 +93,11 @@ integration_tests() { fi echo "File successfully downloaded: $FILE" + if [[ -z "$INVARIANT_API_KEY" ]]; then + echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests." + exit 1 + fi + TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py" if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then diff --git a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py index a2e3db0..1597846 100644 --- a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py @@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header( """Test the Anthropic gateway with Invariant key in the Anthropic key""" anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - { - "ANTHROPIC_API_KEY": anthropic_api_key - + ";invariant-auth=" - }, + {"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix}, ): client = anthropic.Anthropic( http_client=Client(), diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 89591ea..560ada0 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -26,7 +26,7 @@ class WeatherAgent: def __init__(self, gateway_url, push_to_explorer): self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") self.client = anthropic.Anthropic( http_client=Client( headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, diff --git a/tests/integration/anthropic/test_anthropic_without_tool_call.py b/tests/integration/anthropic/test_anthropic_without_tool_call.py index a397c25..c47eaba 100644 --- a/tests/integration/anthropic/test_anthropic_without_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_without_tool_call.py @@ -26,7 +26,7 @@ async def test_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") client = anthropic.Anthropic( http_client=Client( @@ -91,7 +91,7 @@ async def test_streaming_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY") client = anthropic.Anthropic( http_client=Client( diff --git a/tests/integration/gemini/test_generate_content_with_tool_calls.py b/tests/integration/gemini/test_generate_content_with_tool_calls.py index bcf838e..7d2486e 100644 --- a/tests/integration/gemini/test_generate_content_with_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_with_tool_calls.py @@ -151,7 +151,7 @@ async def test_generate_content_with_tool_call( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index d16797a..cf3f42f 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -36,7 +36,7 @@ async def test_generate_content( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) @@ -123,7 +123,7 @@ async def test_generate_content_with_image( if push_to_explorer else f"{gateway_url}/api/v1/gateway/gemini", "headers": { - "invariant-authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests }, ) @@ -181,9 +181,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( """Test the generate content gateway calls with the Invariant API Key in the Gemini Key header.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" gemini_api_key = os.getenv("GEMINI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"GEMINI_API_KEY": gemini_api_key + ";invariant-auth="}, + {"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix}, ): client = genai.Client( api_key=os.getenv("GEMINI_API_KEY"), diff --git a/tests/integration/open_ai/test_chat_with_tool_call.py b/tests/integration/open_ai/test_chat_with_tool_call.py index 69ad6af..5e96db9 100644 --- a/tests/integration/open_ai/test_chat_with_tool_call.py +++ b/tests/integration/open_ai/test_chat_with_tool_call.py @@ -32,7 +32,7 @@ async def test_chat_completion_with_tool_call_without_streaming( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -150,7 +150,7 @@ async def test_chat_completion_with_tool_call_with_streaming( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" diff --git a/tests/integration/open_ai/test_chat_without_tool_calls.py b/tests/integration/open_ai/test_chat_without_tool_calls.py index e9134b3..a8745fb 100644 --- a/tests/integration/open_ai/test_chat_without_tool_calls.py +++ b/tests/integration/open_ai/test_chat_without_tool_calls.py @@ -34,7 +34,7 @@ async def test_chat_completion( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -107,7 +107,7 @@ async def test_chat_completion_with_image( client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" @@ -189,9 +189,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header( """Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" openai_api_key = os.getenv("OPENAI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"OPENAI_API_KEY": openai_api_key + ";invariant-auth="}, + {"OPENAI_API_KEY": openai_api_key + invariant_key_suffix}, ): client = OpenAI( http_client=Client(), @@ -252,7 +253,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream): client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai", From 750c83d3f88025394afcdf503d2dc718bf74efd5 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 14:41:18 +0200 Subject: [PATCH 3/7] Add calls to execute logging guardrails before pushing to explorer. --- gateway/integrations/guardrails.py | 23 -------- gateway/routes/anthropic.py | 59 ++++++++++--------- gateway/routes/gemini.py | 13 +++- gateway/routes/open_ai.py | 41 ++++++++----- ...est_generate_content_without_tool_calls.py | 6 +- 5 files changed, 73 insertions(+), 69 deletions(-) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 412418e..b7377c3 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -1,7 +1,6 @@ """Utility functions for Guardrails execution.""" import asyncio -import json import os import time from typing import Any, Dict, List @@ -351,28 +350,6 @@ async def check_guardrails( async with httpx.AsyncClient() as client: url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") try: - print( - "Hello there this is the request to guardrails: ", - json.dumps( - { - "messages": messages, - "policies": [g.content for g in guardrails], - }, - indent=2, - ), - flush=True, - ) - print( - "Hello there this is the request to guardrails: ", - json.dumps( - { - "Authorization": invariant_authorization, - "Accept": "application/json", - }, - indent=2, - ), - flush=True, - ) result = await client.post( f"{url}/api/v1/policy/check/batch", json={ diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 9905258..09ce85b 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -120,7 +120,7 @@ def create_metadata( def combine_request_and_response_messages( - context: RequestContext, json_response: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ): """Combine the request and response messages""" messages = [] @@ -129,13 +129,13 @@ def combine_request_and_response_messages( {"role": "system", "content": context.request_json.get("system")} ) messages.extend(context.request_json.get("messages", [])) - if len(json_response) > 0: - messages.append(json_response) + if len(response_json) > 0: + messages.append(response_json) return messages async def get_guardrails_check_result( - context: RequestContext, action: GuardrailAction, json_response: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" # Determine which guardrails to apply based on the action @@ -147,7 +147,7 @@ async def get_guardrails_check_result( if not guardrails: return {} - messages = combine_request_and_response_messages(context, json_response) + messages = combine_request_and_response_messages(context, response_json) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check @@ -170,10 +170,22 @@ async def push_to_explorer( guardrails_execution_result.get("errors", []) ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + # Combine the messages from the request body and Anthropic response messages = combine_request_and_response_messages(context, merged_response) - converted_messages = convert_anthropic_to_invariant_message_format(messages) + _ = await push_trace( dataset_name=context.dataset_name, messages=[converted_messages], @@ -200,7 +212,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): # response data self.response: Optional[httpx.Response] = None self.response_string: Optional[str] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing response (if any) self.guardrails_execution_result = {} @@ -209,7 +221,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" if self.context.dataset_guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, action=GuardrailAction.BLOCK, json_response={} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -243,10 +255,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): ) async def request(self): + """Make the request to the Anthropic API.""" self.response = await self.client.send(self.anthropic_request) try: - json_response = self.response.json() + response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -255,11 +268,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=json_response.get("error", "Unknown error from Anthropic"), + detail=response_json.get("error", "Unknown error from Anthropic"), ) - self.json_response = json_response - self.response_string = json.dumps(json_response) + self.response_json = response_json + self.response_string = json.dumps(response_json) return self._make_response( content=self.response_string, @@ -284,7 +297,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): """Checks guardrails after the response is received, and asynchronously pushes to Explorer.""" # ensure the response data is available assert self.response is not None, "response is None" - assert self.json_response is not None, "json_response is None" + assert self.response_json is not None, "response_json is None" assert self.response_string is not None, "response_string is None" if self.context.dataset_guardrails: @@ -292,12 +305,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.json_response, - ) - print( - "Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ", - guardrails_execution_result, - flush=True, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): guardrail_response_string = json.dumps( @@ -313,7 +321,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, guardrails_execution_result, ) ) @@ -330,7 +338,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): # Push to Explorer - don't block on its response asyncio.create_task( push_to_explorer( - self.context, self.json_response, guardrails_execution_result + self.context, self.response_json, guardrails_execution_result ) ) @@ -378,7 +386,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -440,12 +448,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, - ) - print( - "Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ", - self.guardrails_execution_result, - flush=True, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 1e21b90..b390461 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -290,7 +290,6 @@ async def stream_response( async def event_generator(): async for chunk in response.instrumented_event_generator(): yield chunk - print("chunk", chunk) return StreamingResponse( event_generator(), @@ -408,6 +407,18 @@ async def push_to_explorer( guardrails_execution_result.get("errors", []) ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=response_json, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index ada5309..f4a20f4 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -152,7 +152,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -203,7 +203,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.merged_response, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -438,6 +438,19 @@ async def push_to_explorer( not in FINISH_REASON_TO_PUSH_TRACE ): annotations = create_annotations_from_guardrails_errors(guardrails_errors) + + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []) + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + # Combine the messages from the request body and the choices from the OpenAI response messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] @@ -453,7 +466,7 @@ async def push_to_explorer( async def get_guardrails_check_result( context: RequestContext, action: GuardrailAction, - json_response: dict[str, Any] | None = None, + response_json: dict[str, Any] | None = None, ) -> dict[str, Any]: """Get the guardrails check result""" # Determine which guardrails to apply based on the action @@ -466,8 +479,8 @@ async def get_guardrails_check_result( return {} messages = list(context.request_json.get("messages", [])) - if json_response is not None: - messages += [choice["message"] for choice in json_response.get("choices", [])] + if response_json is not None: + messages += [choice["message"] for choice in response_json.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( @@ -499,7 +512,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): # request outputs self.response: Optional[httpx.Response] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing output (if any) self.guardrails_execution_result: Optional[dict] = None @@ -545,7 +558,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.response = await self.client.send(self.open_ai_request) try: - self.json_response = self.response.json() + self.response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -554,10 +567,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=self.json_response.get("error", "Unknown error from OpenAI API"), + detail=self.response_json.get("error", "Unknown error from OpenAI API"), ) - response_string = json.dumps(self.json_response) + response_string = json.dumps(self.response_json) response_code = self.response.status_code return Response( @@ -577,8 +590,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.response is not None ), "on_end called before 'self.response' was available" assert ( - self.json_response is not None - ), "on_end called before 'self.json_response' was available" + self.response_json is not None + ), "on_end called before 'self.response_json' was available" # extract original response status code response_code = self.response.status_code @@ -589,7 +602,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, - json_response=self.json_response, + response_json=self.response_json, ) if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -605,7 +618,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, self.guardrails_execution_result, ) ) @@ -624,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, # include any guardrailing errors if available self.guardrails_execution_result, ) diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index cf3f42f..84ed352 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -195,14 +195,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( chat_response = client.models.generate_content( model="gemini-2.0-flash", - contents="What is the capital of Spain?", + contents="What is the capital of Denmark?", config={ "maxOutputTokens": 100, }, ) # Verify the chat response - assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper() + assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper() expected_assistant_message = chat_response.candidates[0].content.parts[0].text # Wait for the trace to be saved @@ -229,7 +229,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( assert trace["messages"] == [ { "role": "user", - "content": [{"text": "What is the capital of Spain?", "type": "text"}], + "content": [{"text": "What is the capital of Denmark?", "type": "text"}], }, { "role": "assistant", From eced3755b2f09bbf2ce7dc95a8ce56d4378cc2b4 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 15:15:32 +0200 Subject: [PATCH 4/7] Refactor tests. --- run.sh | 2 +- .../test_anthropic_with_tool_call.py | 13 ++--- .../test_anthropic_without_tool_call.py | 28 +++------- tests/integration/docker-compose.test.yml | 1 + .../test_generate_content_with_tool_calls.py | 16 +----- ...est_generate_content_without_tool_calls.py | 27 ++------- .../guardrails/test_guardrails_anthropic.py | 44 +++++---------- .../guardrails/test_guardrails_gemini.py | 48 +++++----------- .../guardrails/test_guardrails_open_ai.py | 38 ++----------- .../open_ai/test_chat_with_tool_call.py | 28 ++-------- .../open_ai/test_chat_without_tool_calls.py | 25 ++------- ...> integration_test_guardrails_via_file.py} | 6 +- tests/integration/utils.py | 56 +++++++++++++++++++ 13 files changed, 121 insertions(+), 211 deletions(-) rename tests/integration/resources/guardrails/{find_capital_guardrails.py => integration_test_guardrails_via_file.py} (87%) create mode 100644 tests/integration/utils.py diff --git a/run.sh b/run.sh index 72ad68e..251f68d 100755 --- a/run.sh +++ b/run.sh @@ -98,7 +98,7 @@ integration_tests() { exit 1 fi - TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py" + TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/integration_test_guardrails_via_file.py" if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then TEST_GUARDRAILS_FILE_PATH=$(realpath "$TEST_GUARDRAILS_FILE_PATH") diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 560ada0..abc3937 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -12,10 +12,11 @@ from typing import Dict, List # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_anthropic_client + import anthropic import pytest import requests -from httpx import Client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -26,14 +27,8 @@ class WeatherAgent: def __init__(self, gateway_url, push_to_explorer): self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY") - self.client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{self.dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + self.client = get_anthropic_client( + gateway_url, push_to_explorer, self.dataset_name ) self.get_weather_function = { "name": "get_weather", diff --git a/tests/integration/anthropic/test_anthropic_without_tool_call.py b/tests/integration/anthropic/test_anthropic_without_tool_call.py index c47eaba..280341f 100644 --- a/tests/integration/anthropic/test_anthropic_without_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_without_tool_call.py @@ -8,10 +8,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -import anthropic +from utils import get_anthropic_client + import pytest import requests -from httpx import Client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -26,15 +26,10 @@ async def test_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY") - - client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) cities = ["zurich", "new york", "london"] @@ -91,16 +86,7 @@ async def test_streaming_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY") - - client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", - ) + client = get_anthropic_client(gateway_url, push_to_explorer, dataset_name) cities = ["zurich", "new york", "london"] queries = [ diff --git a/tests/integration/docker-compose.test.yml b/tests/integration/docker-compose.test.yml index 52d481e..8c1afe3 100644 --- a/tests/integration/docker-compose.test.yml +++ b/tests/integration/docker-compose.test.yml @@ -60,6 +60,7 @@ services: app-api: container_name: invariant-gateway-test-explorer-app-api image: ghcr.io/invariantlabs-ai/explorer/app-api:latest + pull_policy: always platform: linux/amd64 depends_on: database: diff --git a/tests/integration/gemini/test_generate_content_with_tool_calls.py b/tests/integration/gemini/test_generate_content_with_tool_calls.py index 7d2486e..96c4bd7 100644 --- a/tests/integration/gemini/test_generate_content_with_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_with_tool_calls.py @@ -8,9 +8,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client + import pytest import requests -from google import genai from google.genai import types # Pytest plugins @@ -143,18 +144,7 @@ async def test_generate_content_with_tool_call( without streaming. """ dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - }, - ) + client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gemini-2.0-flash", diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index 84ed352..27036a2 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -10,6 +10,8 @@ from unittest.mock import patch # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client + import pytest import PIL.Image import requests @@ -29,17 +31,8 @@ async def test_generate_content( ): """Test the generate content gateway calls without tool calling.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - }, - ) + client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) + request = { "model": "gemini-2.0-flash", "contents": "What is the capital of France?", @@ -115,18 +108,8 @@ async def test_generate_content_with_image( ): """Test that generate content gateway calls work with image.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - }, - ) image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png" image = PIL.Image.open(image_path) diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index 173e5ab..aaa349a 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -8,6 +8,8 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_anthropic_client + import pytest import requests from httpx import Client @@ -32,16 +34,10 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -161,16 +157,10 @@ async def test_tool_call_guardrail_from_file( } dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) if not do_stream: @@ -255,16 +245,10 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index c463186..e452284 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -8,9 +8,10 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client + import pytest import requests -from httpx import Client from google import genai # Pytest plugins @@ -30,17 +31,10 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -141,17 +135,10 @@ async def test_tool_call_guardrail_from_file( ) dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -244,17 +231,10 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index acc2f67..c15989a 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -8,6 +8,8 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client + import pytest import requests from httpx import Client @@ -30,17 +32,7 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gpt-4o", @@ -161,17 +153,7 @@ async def test_tool_call_guardrail_from_file( } dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) if not do_stream: with pytest.raises(BadRequestError) as exc_info: @@ -259,17 +241,7 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gpt-4o", diff --git a/tests/integration/open_ai/test_chat_with_tool_call.py b/tests/integration/open_ai/test_chat_with_tool_call.py index 5e96db9..ba928c6 100644 --- a/tests/integration/open_ai/test_chat_with_tool_call.py +++ b/tests/integration/open_ai/test_chat_with_tool_call.py @@ -9,10 +9,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client + import pytest import requests -from httpx import Client -from openai import OpenAI # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -28,17 +28,7 @@ async def test_chat_completion_with_tool_call_without_streaming( without streaming. """ dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", @@ -146,17 +136,7 @@ async def test_chat_completion_with_tool_call_with_streaming( while streaming. """ dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", diff --git a/tests/integration/open_ai/test_chat_without_tool_calls.py b/tests/integration/open_ai/test_chat_without_tool_calls.py index a8745fb..69d7816 100644 --- a/tests/integration/open_ai/test_chat_without_tool_calls.py +++ b/tests/integration/open_ai/test_chat_without_tool_calls.py @@ -11,6 +11,8 @@ from unittest.mock import patch # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client + import pytest import requests from httpx import Client @@ -30,17 +32,7 @@ async def test_chat_completion( ): """Test the chat completions gateway calls without tool calling.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", @@ -103,17 +95,8 @@ async def test_chat_completion_with_image( ): """Test the chat completions gateway works with image.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png" with image_path.open("rb") as image_file: base64_image = base64.b64encode(image_file.read()).decode("utf-8") diff --git a/tests/integration/resources/guardrails/find_capital_guardrails.py b/tests/integration/resources/guardrails/integration_test_guardrails_via_file.py similarity index 87% rename from tests/integration/resources/guardrails/find_capital_guardrails.py rename to tests/integration/resources/guardrails/integration_test_guardrails_via_file.py index 282720f..fe31990 100644 --- a/tests/integration/resources/guardrails/find_capital_guardrails.py +++ b/tests/integration/resources/guardrails/integration_test_guardrails_via_file.py @@ -1,15 +1,15 @@ -# For Anthropic and OpenAI tests: +# For Anthropic and OpenAI tests raise "Madrid detected in the response" if: (msg: Message) "Madrid" in msg.content -# For Gemini tests: +# For Gemini tests raise "Dublin detected in the response" if: (msg: Message) (chunk: str) in text(msg.content) "Dublin" in chunk -# For all provider tests: +# For all provider tests raise "get_capital is called with Germany as argument" if: (call: ToolCall) call is tool:get_capital diff --git a/tests/integration/utils.py b/tests/integration/utils.py new file mode 100644 index 0000000..39c82b6 --- /dev/null +++ b/tests/integration/utils.py @@ -0,0 +1,56 @@ +"""Common utilities for integration tests.""" + +import os +from httpx import Client +from openai import OpenAI +from google import genai +from anthropic import Anthropic + + +def get_open_ai_client( + gateway_url: str, push_to_explorer: bool, dataset_name: str +) -> OpenAI: + """Create an OpenAI client for integration tests.""" + return OpenAI( + http_client=Client( + headers={ + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + ), + base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/openai", + ) + + +def get_anthropic_client( + gateway_url: str, push_to_explorer: bool, dataset_name: str +) -> Anthropic: + """Create an Anthropic client for integration tests.""" + return Anthropic( + http_client=Client( + headers={ + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + ), + base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/anthropic", + ) + + +def get_gemini_client( + gateway_url: str, push_to_explorer: bool, dataset_name: str +) -> genai.Client: + """Create a Gemini client for integration tests.""" + return genai.Client( + api_key=os.getenv("GEMINI_API_KEY"), + http_options={ + "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/gemini", + "headers": { + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + }, + ) From 9aef873a74d1ac6bac858dbc5150661f4cb47406 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 1 Apr 2025 15:42:01 +0200 Subject: [PATCH 5/7] Correct header name before calling explorer to fetch guardrails. --- gateway/integrations/explorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 29e0710..09b32b1 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -111,7 +111,7 @@ async def fetch_guardrails_from_explorer( client = httpx.AsyncClient( base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), headers={ - "Invariant-Authorization": invariant_authorization, + "Authorization": invariant_authorization, }, ) From 55f0f741c0608911a7a970ca21ad4752fe47af11 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 2 Apr 2025 07:19:56 +0200 Subject: [PATCH 6/7] Add tests for guardrails integration with explorer. --- gateway/integrations/explorer.py | 7 +- gateway/routes/anthropic.py | 4 +- gateway/routes/gemini.py | 4 +- gateway/routes/open_ai.py | 8 +- .../guardrails/test_guardrails_anthropic.py | 151 ++++++++++++++++- .../guardrails/test_guardrails_gemini.py | 143 +++++++++++++++- .../guardrails/test_guardrails_open_ai.py | 152 +++++++++++++++++- tests/integration/utils.py | 51 +++++- 8 files changed, 503 insertions(+), 17 deletions(-) diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 09b32b1..22a104d 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -14,7 +14,7 @@ DEFAULT_API_URL = "https://explorer.invariantlabs.ai" def create_annotations_from_guardrails_errors( - guardrails_errors: List[dict], + guardrails_errors: List[dict], action: str = "block" ) -> List[AnnotationCreate]: """Create Explorer annotations from the guardrails errors.""" annotations = [] @@ -48,7 +48,10 @@ def create_annotations_from_guardrails_errors( AnnotationCreate( content=content, address=r, - extra_metadata={"source": "guardrails-error"}, + extra_metadata={ + "source": "guardrails-error", + "guardrail-action": action, + }, ) ) return annotations diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 09ce85b..24cf097 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -167,7 +167,7 @@ async def push_to_explorer( """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []) + guardrails_execution_result.get("errors", []), action="block" ) # Execute the logging guardrails before pushing to Explorer @@ -177,7 +177,7 @@ async def push_to_explorer( response_json=merged_response, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []) + logging_guardrails_execution_result.get("errors", []), action="log" ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index b390461..6d4a409 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -404,7 +404,7 @@ async def push_to_explorer( """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []) + guardrails_execution_result.get("errors", []), action="block" ) # Execute the logging guardrails before pushing to Explorer @@ -414,7 +414,7 @@ async def push_to_explorer( response_json=response_json, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []) + logging_guardrails_execution_result.get("errors", []), action="log" ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index f4a20f4..f929a2c 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -221,7 +221,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): # push will happen in on_end async def on_end(self): - """Sends full merged response to the exploree.""" + """Sends full merged response to the explorer.""" # don't block on the response from explorer (.create_task) if self.context.dataset_name: asyncio.create_task( @@ -437,7 +437,9 @@ async def push_to_explorer( and merged_response["choices"][0].get("finish_reason") not in FINISH_REASON_TO_PUSH_TRACE ): - annotations = create_annotations_from_guardrails_errors(guardrails_errors) + annotations = create_annotations_from_guardrails_errors( + guardrails_errors, action="block" + ) # Execute the logging guardrails before pushing to Explorer logging_guardrails_execution_result = await get_guardrails_check_result( @@ -446,7 +448,7 @@ async def push_to_explorer( response_json=merged_response, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []) + logging_guardrails_execution_result.get("errors", []), action="log" ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations) diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index aaa349a..f61d9e6 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -8,12 +8,11 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_anthropic_client +from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset import pytest import requests -from httpx import Client -from anthropic import Anthropic, APIStatusError, BadRequestError +from anthropic import APIStatusError, BadRequestError # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -316,3 +315,149 @@ async def test_input_from_guardrail_from_file( == "Users must not mention the magic phrase 'Fight Club'" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + +@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + client = get_anthropic_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the explorer guardrails + spain_request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [{"role": "user", "content": "What is the capital of Spain?"}], + "max_tokens": 100, + } + if not do_stream: + chat_response = client.messages.create( + **spain_request, + stream=False, + ) + + assert "Madrid" in chat_response.content[0].text + else: + chat_response = client.messages.create( + **spain_request, + stream=True, + ) + + merged_content = "" + for chunk in chat_response: + if chunk.type == "content_block_delta": + merged_content += chunk.delta.text + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." + shrek_request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.messages.create( + **shrek_request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create( + **shrek_request, + stream=True, + ) + + for _ in chat_response: + pass + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index e452284..6fc0945 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -8,7 +8,7 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_gemini_client +from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset import pytest import requests @@ -303,6 +303,147 @@ async def test_input_from_guardrail_from_file( ) +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the explorer guardrails + spain_request = { + "model": "gemini-2.0-flash", + "contents": "What is the capital of Spain?", + "config": { + "maxOutputTokens": 100, + }, + } + if not do_stream: + chat_response = client.models.generate_content(**spain_request) + + assert "Madrid" in chat_response.candidates[0].content.parts[0].text + else: + chat_response = client.models.generate_content_stream(**spain_request) + + merged_content = "" + for chunk in chat_response: + if ( + chunk.candidates + and chunk.candidates[0].content + and chunk.candidates[0].content.parts + ): + for text_part in chunk.candidates[0].content.parts: + merged_content += text_part.text + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." + shrek_request = { + "model": "gemini-2.0-flash", + "contents": user_prompt, + "config": { + "maxOutputTokens": 100, + }, + } + if not do_stream: + with pytest.raises(genai.errors.ClientError) as exc_info: + client.models.generate_content(**shrek_request) + + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + response = client.models.generate_content_stream(**shrek_request) + + assert_is_streamed_refusal( + response, + [ + "[Invariant] The response did not pass the guardrails", + "ogre detected in response", + ], + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt, + } + ], + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) + + def is_refusal(chunk): return ( len(chunk.candidates) == 1 diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index c15989a..b0c6b24 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -8,12 +8,11 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_open_ai_client +from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset import pytest import requests -from httpx import Client -from openai import OpenAI, BadRequestError, APIError +from openai import BadRequestError, APIError # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -321,3 +320,150 @@ async def test_input_from_guardrail_from_file( == "Users must not mention the magic phrase 'Fight Club'" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the explorer guardrails + spain_request = { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "What is the capital of Spain?"}], + "max_tokens": 100, + } + if not do_stream: + chat_response = client.chat.completions.create( + **spain_request, + stream=False, + ) + + assert "Madrid" in chat_response.choices[0].message.content + else: + chat_response = client.chat.completions.create( + **spain_request, + stream=True, + ) + + merged_content = "" + for chunk in chat_response: + if chunk.choices[0].delta.content: + merged_content += chunk.choices[0].delta.content + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." + shrek_request = { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **shrek_request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **shrek_request, + stream=True, + ) + + for _ in chat_response: + pass + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 39c82b6..6df9ce9 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -1,7 +1,10 @@ """Common utilities for integration tests.""" import os -from httpx import Client +import uuid +from typing import Any, Dict, Literal, Optional + +from httpx import AsyncClient, Client from openai import OpenAI from google import genai from anthropic import Anthropic @@ -54,3 +57,49 @@ def get_gemini_client( }, }, ) + + +async def create_dataset( + explorer_api_url: str, + invariant_authorization: str, + dataset_name: Optional[str] = None, +) -> Dict[str, Any]: + """Create a dataset in the Explorer API.""" + client = Client(base_url=explorer_api_url) + response = client.post( + "/api/v1/dataset/create", + json={"name": dataset_name if dataset_name else f"test-dataset-{uuid.uuid4()}"}, + headers={"Authorization": invariant_authorization}, + timeout=5, + ) + if response.status_code != 200: + raise ValueError( + f"Failed to create dataset: {response.status_code}, {response.text}" + ) + return response.json() + + +async def add_guardrail_to_dataset( + explorer_api_url: str, + dataset_id: str, + policy: str, + action: Literal["block", "log"], + invariant_authorization: str, +) -> Dict[str, Any]: + """Add a guardrail to a dataset.""" + client = Client(base_url=explorer_api_url) + response = client.post( + f"/api/v1/dataset/{dataset_id}/policy", + json={ + "action": action, + "policy": policy, + "name": f"test-guardrail-{uuid.uuid4()}", + }, + headers={"Authorization": invariant_authorization}, + timeout=5, + ) + if response.status_code != 200: + raise ValueError( + f"Failed to add guardrail: {response.status_code}, {response.text}" + ) + return response.json() From f3a56e1e43faedb6e30549684b1ce8dfe783a55c Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 2 Apr 2025 11:25:39 +0200 Subject: [PATCH 7/7] Add preguardrailing tests for guardrails pulled from explorer. --- gateway/routes/gemini.py | 1 + gateway/routes/open_ai.py | 33 +++-- .../guardrails/test_guardrails_anthropic.py | 129 +++++++++++++++++- .../guardrails/test_guardrails_gemini.py | 112 +++++++++++++++ .../guardrails/test_guardrails_open_ai.py | 126 ++++++++++++++++- 5 files changed, 379 insertions(+), 22 deletions(-) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 6d4a409..2643125 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -536,6 +536,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def on_end(self): + """Runs when the request ends.""" response_string = json.dumps(self.response_json) response_code = self.response.status_code diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index f929a2c..ff565e0 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -432,27 +432,26 @@ async def push_to_explorer( # or if the guardrails check returned errors. guardrails_execution_result = guardrails_execution_result or {} guardrails_errors = guardrails_execution_result.get("errors", []) - if guardrails_errors or not ( + annotations = create_annotations_from_guardrails_errors( + guardrails_errors, action="block" + ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []), action="log" + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + + if annotations or not ( merged_response.get("choices") and merged_response["choices"][0].get("finish_reason") not in FINISH_REASON_TO_PUSH_TRACE ): - annotations = create_annotations_from_guardrails_errors( - guardrails_errors, action="block" - ) - - # Execute the logging guardrails before pushing to Explorer - logging_guardrails_execution_result = await get_guardrails_check_result( - context, - action=GuardrailAction.LOG, - response_json=merged_response, - ) - logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []), action="log" - ) - # Update the annotations with the logging guardrails - annotations.extend(logging_annotations) - # Combine the messages from the request body and the choices from the OpenAI response messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index f61d9e6..035a845 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -316,7 +316,10 @@ async def test_input_from_guardrail_from_file( and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) @pytest.mark.parametrize("do_stream", [True, False]) async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): """Test that the guardrails from the explorer work.""" @@ -461,3 +464,127 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["guardrail-action"] == "log" ) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + client = get_anthropic_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.messages.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + + else: + if do_stream: + _ = client.messages.create( + **request, + stream=True, + ) + else: + _ = client.messages.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 if not is_block_action else 1 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + ) diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index 6fc0945..b3ac35e 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -444,6 +444,118 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s ) +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "gemini-2.0-flash", + "contents": user_prompt, + "config": { + "maxOutputTokens": 100, + }, + } + if is_block_action: + if do_stream: + chat_response = client.models.generate_content_stream(**request) + + assert_is_streamed_refusal( + chat_response, + [ + "[Invariant] The request did not pass the guardrails", + "pun detected in user message", + ], + ) + else: + with pytest.raises(genai.errors.ClientError) as exc_info: + chat_response = client.models.generate_content(**request) + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + response = client.models.generate_content_stream(**request) + for _ in response: + pass + else: + _ = client.models.generate_content(**request) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 if not is_block_action else 1 + assert trace["messages"][0] == { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt, + } + ], + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + ) + + def is_refusal(chunk): return ( len(chunk.candidates) == 1 diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index b0c6b24..6031778 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -412,15 +412,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s **shrek_request, stream=True, ) - for _ in chat_response: pass + assert "[Invariant] The response did not pass the guardrails" in str( exc_info.value ) - # Only the block guardrail should be triggered here - assert "ogre detected in response" in str(exc_info.value) - assert "Fiona detected in response" not in str(exc_info.value) # Wait for the trace to be saved # This is needed because the trace is saved asynchronously @@ -467,3 +464,124 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["guardrail-action"] == "log" ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + _ = client.chat.completions.create( + **request, + stream=True, + ) + else: + _ = client.chat.completions.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 1 if is_block_action else 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + )