diff --git a/client.py b/client.py new file mode 100644 index 0000000..759337d --- /dev/null +++ b/client.py @@ -0,0 +1,29 @@ +from openai import OpenAI +from httpx import Client +import os + +# unicode escape everything +guardrails = """ +raise "Rule 1: Do not talk about Fight Club" if: + (msg: Message) + "fight club" in msg.content +""".encode("unicode_escape") + +openai_client = OpenAI( + default_headers={ + "Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"), + "Invariant-Guardrails": guardrails, + }, + base_url="http://localhost:8000/api/v1/gateway/non-streaming/openai", +) + +response = openai_client.chat.completions.create( + model="gpt-4", + messages=[ + { + "role": "user", + "content": "What can you tell me about fight club?", + } + ], +) +print("Response: ", response.choices[0].message.content) diff --git a/gateway/.env b/gateway/.env index 78d3391..c84df8b 100644 --- a/gateway/.env +++ b/gateway/.env @@ -1,4 +1,4 @@ POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres - POSTGRES_DB=invariantmonitor - POSTGRES_HOST=database \ No newline at end of file +POSTGRES_PASSWORD=postgres +POSTGRES_DB=invariantmonitor +POSTGRES_HOST=database \ No newline at end of file diff --git a/gateway/common/authorization.py b/gateway/common/authorization.py index 6a4e844..ac63467 100644 --- a/gateway/common/authorization.py +++ b/gateway/common/authorization.py @@ -8,8 +8,10 @@ API_KEYS_SEPARATOR = ";invariant-auth=" def extract_authorization_from_headers( - request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str -) -> Tuple[str, str]: + request: Request, + dataset_name: Optional[str] = None, + llm_provider_api_key_header: Optional[str] = None, +) -> Tuple[Optional[str], Optional[str]]: """ Extracts the Invariant authorization and LLM Provider API key from the request headers. 
@@ -26,8 +28,15 @@ def extract_authorization_from_headers( The header in that case becomes: {llm_provider_api_key_header}: ";invariant-auth=" """ + # invariant api key invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER) - llm_provider_api_key = request.headers.get(llm_provider_api_key_header) + # llm provider api key + if llm_provider_api_key_header is not None: + llm_provider_api_key = request.headers.get(llm_provider_api_key_header) + else: + llm_provider_api_key = None + + # if the dataset name is not None, we need to check if the invariant api key is present if dataset_name: if invariant_authorization is None: if llm_provider_api_key is None: @@ -43,9 +52,7 @@ def extract_authorization_from_headers( API_KEYS_SEPARATOR ) if len(api_keys) != 2 or not api_keys[1].strip(): - raise HTTPException( - status_code=400, detail="Invalid API Key format" - ) + raise HTTPException(status_code=400, detail="Invalid API Key format") invariant_authorization = f"Bearer {api_keys[1].strip()}" llm_provider_api_key = f"{api_keys[0].strip()}" diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 5978690..c785dfd 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -8,6 +8,9 @@ from typing import Optional import fastapi from httpx import HTTPStatusError +from common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet +from common.authorization import extract_authorization_from_headers + def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]: """ @@ -29,8 +32,8 @@ def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[ class GatewayConfig: """Common configurations for the Gateway Server.""" - def __init__(self, guardrails: Optional[str] = None): - self.guardrails = guardrails or self._load_guardrails_from_file() + def __init__(self): + self.guardrails = self._load_guardrails_from_file() def _load_guardrails_from_file(self) -> 
str: """ @@ -67,13 +70,7 @@ class GatewayConfig: raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e def __repr__(self) -> str: - return f"GatewayConfig(guardrails={repr(self.guardrails)})" - - def with_guardrails(self, guardrails: str) -> "GatewayConfig": - """ - Returns a new GatewayConfig instance with the specified guardrails. - """ - return GatewayConfig(guardrails) + return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})" class GatewayConfigManager: @@ -94,8 +91,20 @@ class GatewayConfigManager: local_config = GatewayConfig() cls._config_instance = local_config - # if provided in header, use custom guardrailing policy - if guardrail_file_contents := extract_policy_from_headers(request): - local_config = local_config.with_guardrails(guardrail_file_contents) - return local_config + + +async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]: + # if provided in header, use custom guardrailing policy + if guardrails := extract_policy_from_headers(request): + return GuardrailRuleSet( + blocking_guardrails=[ + Guardrail( + id="guardrail-from-header", + name="guardrails from request header", + content=guardrails, + action=GuardrailAction.BLOCK, + ) + ], + logging_guardrails=[], + ) diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py new file mode 100644 index 0000000..e4164c1 --- /dev/null +++ b/gateway/common/guardrails.py @@ -0,0 +1,31 @@ +"""Common guardrails data class.""" + +from enum import Enum +from typing import List + +from dataclasses import dataclass + + +class GuardrailAction(str, Enum): + """Enum representing the action to be taken for guardrail rules.""" + + BLOCK = "block" + LOG = "log" + + +@dataclass(frozen=True) +class Guardrail: + """Represents a single guardrail rule.""" + + id: str + name: str + content: str + action: GuardrailAction + + +@dataclass(frozen=True) +class GuardrailRuleSet: + """Grouped guardrail rules separated by their action.""" + + 
blocking_guardrails: List[Guardrail] + logging_guardrails: List[Guardrail] diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py new file mode 100644 index 0000000..0c4a31b --- /dev/null +++ b/gateway/common/request_context.py @@ -0,0 +1,93 @@ +"""Common Request context data class.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from common.config_manager import GatewayConfig +from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction + + +@dataclass(frozen=True) +class RequestContext: + """Structured context for a request. Must be created via `RequestContext.create()`.""" + + request_json: Dict[str, Any] + dataset_name: Optional[str] = None + invariant_authorization: Optional[str] = None + # the set of guardrails to enforce for this request + guardrails: Optional[GuardrailRuleSet] = None + config: Dict[str, Any] = None + + _created_via_factory: bool = field( + default=False, init=True, repr=False, compare=False + ) + + def __post_init__(self): + if not self._created_via_factory: + raise RuntimeError( + "RequestContext must be created using RequestContext.create()" + ) + + @classmethod + def create( + cls, + request_json: Dict[str, Any], + dataset_name: Optional[str] = None, + invariant_authorization: Optional[str] = None, + guardrails: Optional[GuardrailRuleSet] = None, + config: Optional[GatewayConfig] = None, + ) -> "RequestContext": + """Creates a new RequestContext instance, applying default guardrails if needed.""" + + # Convert GatewayConfig to a basic dict, excluding guardrails_from_file + context_config = { + key: value + for key, value in (config.__dict__.items() if config else {}) + if key != "guardrails_from_file" + } + + # If no guardrails are configured for the dataset on Explorer, + # and the config specifies guardrails_from_file, use that. 
+ guardrails = guardrails + if ( + ( + not guardrails + or ( + not guardrails.blocking_guardrails + and not guardrails.logging_guardrails + ) + ) + and config + and config.guardrails_from_file + ): + # TODO: Support logging guardrails via file. + guardrails = GuardrailRuleSet( + blocking_guardrails=[ + Guardrail( + id="default", + name="default", + content=config.guardrails_from_file, + action=GuardrailAction.BLOCK, + ) + ], + logging_guardrails=[], + ) + + return cls( + request_json=request_json, + dataset_name=dataset_name, + invariant_authorization=invariant_authorization, + guardrails=guardrails, + config=context_config, + _created_via_factory=True, + ) + + def __repr__(self) -> str: + return ( + f"RequestContext(" + f"request_json={self.request_json}, " + f"dataset_name={self.dataset_name}, " + f"invariant_authorization={self.invariant_authorization}, " + f"guardrails={self.guardrails}, " + f"config={self.config})" + ) diff --git a/gateway/common/request_context_data.py b/gateway/common/request_context_data.py deleted file mode 100644 index 8ae98ba..0000000 --- a/gateway/common/request_context_data.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Common Request context data class.""" - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from common.config_manager import GatewayConfig - - -@dataclass(frozen=True) -class RequestContextData: - """Request context data class.""" - - request_json: Dict[str, Any] - dataset_name: Optional[str] = None - invariant_authorization: Optional[str] = None - config: Optional[GatewayConfig] = None diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index fd0760b..5fbb8d8 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -3,15 +3,18 @@ import os from typing import Any, Dict, List +from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import 
PushTracesRequest, PushTracesResponse from invariant_sdk.types.annotations import AnnotationCreate +import httpx + DEFAULT_API_URL = "https://explorer.invariantlabs.ai" def create_annotations_from_guardrails_errors( - guardrails_errors: List[dict], + guardrails_errors: List[dict], action: str = "block" ) -> List[AnnotationCreate]: """Create Explorer annotations from the guardrails errors.""" annotations = [] @@ -45,7 +48,10 @@ def create_annotations_from_guardrails_errors( AnnotationCreate( content=content, address=r, - extra_metadata={"source": "guardrails-error"}, + extra_metadata={ + "source": "guardrails-error", + "guardrail-action": action, + }, ) ) return annotations @@ -91,3 +97,79 @@ async def push_trace( except Exception as e: print(f"Failed to push trace: {e}") return {"error": str(e)} + + +async def fetch_guardrails_from_explorer( + dataset_name: str, invariant_authorization: str +) -> GuardrailRuleSet: + """Get the guardrails for the dataset. + + Returns: + GuardrailRuleSet: The guardrails for the dataset grouped by their action. + """ + + # TODO: Implement a single API in explorer backend which can return + # dataset details without requiring a username. + + client = httpx.AsyncClient( + base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + headers={ + "Authorization": invariant_authorization, + }, + ) + + # Get the user details. + user_info_response = await client.get("/api/v1/user/info") + if user_info_response.status_code != 200: + raise ValueError( + f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}" + ) + user_details = user_info_response.json() + username = user_details["username"] + + # Get the dataset policies. + policies_response = await client.get( + f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy" + ) + if policies_response.status_code != 200: + if policies_response.status_code == 404: + # If the dataset does not exist, return empty guardrails. 
+ return GuardrailRuleSet( + blocking_guardrails=[], + logging_guardrails=[], + ) + raise ValueError( + f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}" + ) + policies_details = policies_response.json() + guardrails = policies_details.get("policies", []) + + blocking_guardrails = [] + logging_guardrails = [] + for g in guardrails: + action = g["action"] + + if not g["enabled"]: + # Skip guardrails that are not enabled. + continue + + if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG): + print("[Warning] Skipping unknown guardrail action: ", action) + continue + + guardrail = Guardrail( + id=g["id"], + name=g["name"], + content=g["content"], + action=GuardrailAction(action), + ) + + if action == GuardrailAction.BLOCK: + blocking_guardrails.append(guardrail) + else: + logging_guardrails.append(guardrail) + + return GuardrailRuleSet( + blocking_guardrails=blocking_guardrails, + logging_guardrails=logging_guardrails, + ) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index e3c236f..2785788 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -7,7 +7,8 @@ from typing import Any, Dict, List from functools import wraps import httpx -from common.request_context_data import RequestContextData +from common.guardrails import Guardrail +from common.request_context import RequestContext DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -81,21 +82,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None: result.raise_for_status() -async def preload_guardrails(context: "RequestContextData") -> None: +async def preload_guardrails(context: "RequestContext") -> None: """ Preloads the guardrails for faster checking later. Args: - context: RequestContextData object. + context: RequestContext object. 
""" - if not context.config or not context.config.guardrails: + if not context.guardrails: return try: - task = asyncio.create_task( - _preload(context.config.guardrails, context.invariant_authorization) - ) - asyncio.shield(task) + # Move these calls to a batch preload/validate API. + for blocking_guardrail in context.guardrails.blocking_guardrails: + task = asyncio.create_task( + _preload(blocking_guardrail.content, context.invariant_authorization) + ) + asyncio.shield(task) + for logging_guadrail in context.guardrails.logging_guardrails: + task = asyncio.create_task( + _preload(logging_guadrail.content, context.invariant_authorization) + ) + asyncio.shield(task) except Exception as e: print(f"Error scheduling preload_guardrails task: {e}") @@ -322,14 +330,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse): async def check_guardrails( - messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str + messages: List[Dict[str, Any]], + guardrails: List[Guardrail], + invariant_authorization: str, ) -> Dict[str, Any]: """ Checks guardrails on the list of messages. + This calls the batch check API of the Guardrails service. Args: messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (str): The guardrails to check against. + guardrails (List[Guardrail]): The guardrails to check against. invariant_authorization (str): Value of the invariant-authorization header. 
@@ -340,8 +351,11 @@ async def check_guardrails( url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") try: result = await client.post( - f"{url}/api/v1/policy/check", - json={"messages": messages, "policy": guardrails}, + f"{url}/api/v1/policy/check/batch", + json={ + "messages": messages, + "policies": [g.content for g in guardrails], + }, headers={ "Authorization": invariant_authorization, "Accept": "application/json", @@ -351,8 +365,20 @@ async def check_guardrails( raise Exception( f"Guardrails check failed: {result.status_code} - {result.text}" ) - print(f"Guardrail check response: {result.json()}") - return result.json() + guardrails_result = result.json() + + aggregated_errors = {"errors": []} + for res in guardrails_result.get("result", []): + aggregated_errors["errors"].extend(res.get("errors", [])) + + # check for any error_message + if error_message := res.get("error_message"): + return { + "errors": [ + {"args": [error_message], "kwargs": {}, "ranges": []} + ] + } + return aggregated_errors except Exception as e: print(f"Failed to verify guardrails: {e}") # make sure runtime errors are also visible in e.g. 
Explorer diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 2f3e243..4fd0744 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -5,20 +5,29 @@ import json from typing import Any, Optional import httpx -from regex import R -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import ( + GatewayConfig, + GatewayConfigManager, + GuardrailsInHeader, +) from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from common.guardrails import GuardrailAction, GuardrailRuleSet +from common.request_context import RequestContext from converters.anthropic_to_invariant import ( convert_anthropic_to_invariant_message_format, ) -from common.authorization import extract_authorization_from_headers -from common.request_context_data import RequestContextData +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -61,6 +70,7 @@ async def anthropic_v1_messages_gateway( request: Request, dataset_name: str = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument + header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), ): """Proxy calls to the Anthropic APIs""" headers = { @@ -83,21 +93,26 @@ async def anthropic_v1_messages_gateway( data=request_body, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset from explorer. 
+ dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + guardrails=header_guardrails or dataset_guardrails, config=config, ) - asyncio.create_task(preload_guardrails(context)) - if request_json.get("stream"): return await handle_streaming_response(context, client, anthropic_request) return await handle_non_streaming_response(context, client, anthropic_request) def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = {k: v for k, v in context.request_json.items() if k != "messages"} @@ -108,7 +123,7 @@ def create_metadata( def combine_request_and_response_messages( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ): """Combine the request and response messages""" messages = [] @@ -117,42 +132,63 @@ def combine_request_and_response_messages( {"role": "system", "content": context.request_json.get("system")} ) messages.extend(context.request_json.get("messages", [])) - if len(json_response) > 0: - messages.append(json_response) + if len(response_json) > 0: + messages.append(response_json) return messages async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" - messages = combine_request_and_response_messages(context, json_response) + # Determine which guardrails to apply based on the action + guardrails = ( + context.guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.guardrails.blocking_guardrails + ) + if not guardrails: + return {} + + messages = 
combine_request_and_response_messages(context, response_json) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []) + guardrails_execution_result.get("errors", []), action="block" ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []), action="log" + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + # Combine the messages from the request body and Anthropic response messages = combine_request_and_response_messages(context, merged_response) - converted_messages = convert_anthropic_to_invariant_message_format(messages) + _ = await push_trace( dataset_name=context.dataset_name, messages=[converted_messages], @@ -163,30 +199,32 @@ async def push_to_explorer( class InstrumentedAnthropicResponse(InstrumentedResponse): + """Instrumented response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() - self.context: 
RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request # response data self.response: Optional[httpx.Response] = None self.response_string: Optional[str] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing response (if any) self.guardrails_execution_result = {} async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -220,10 +258,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): ) async def request(self): + """Make the request to the Anthropic API.""" self.response = await self.client.send(self.anthropic_request) try: - json_response = self.response.json() + response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -232,11 +271,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=json_response.get("error", "Unknown error from Anthropic"), + detail=response_json.get("error", "Unknown error from Anthropic"), ) - self.json_response = json_response - self.response_string = json.dumps(json_response) + self.response_json = response_json + self.response_string = json.dumps(response_json) return self._make_response( content=self.response_string, @@ -261,13 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): """Checks guardrails after the response is received, and 
asynchronously pushes to Explorer.""" # ensure the response data is available assert self.response is not None, "response is None" - assert self.json_response is not None, "json_response is None" + assert self.response_json is not None, "response_json is None" assert self.response_string is not None, "response_string is None" - if self.context.config and self.context.config.guardrails: + if self.context.guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): guardrail_response_string = json.dumps( @@ -283,7 +324,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, guardrails_execution_result, ) ) @@ -300,13 +341,13 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): # Push to Explorer - don't block on its response asyncio.create_task( push_to_explorer( - self.context, self.json_response, guardrails_execution_result + self.context, self.response_json, guardrails_execution_result ) ) async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> Response: @@ -320,17 +361,19 @@ async def handle_non_streaming_response( return await response.instrumented_request() -class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): +class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Anthropic API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = 
context self.client: httpx.AsyncClient = client self.anthropic_request: httpx.Request = anthropic_request @@ -342,9 +385,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def on_start(self): """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -392,6 +437,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Process the chunk and update the merged_response""" decoded_chunk = chunk.decode().strip() if not decoded_chunk: return @@ -400,14 +446,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): process_chunk(decoded_chunk, self.merged_response) # on last stream chunk, run output guardrails - if ( - "event: message_stop" in decoded_chunk - and self.context.config - and self.context.config.guardrails - ): + if "event: message_stop" in decoded_chunk and self.context.guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -420,7 +464,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): } ) - # yield an extra error chunk (without preventing the original chunk to go through after, + # yield an extra error chunk (without preventing the original chunk + # to go through after, # so client gets the proper message_stop event 
still) return ExtraItem( value=f"event: error\ndata: {error_chunk}\n\n".encode() @@ -440,12 +485,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse): async def handle_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, anthropic_request: httpx.Request, ) -> StreamingResponse: """Handles streaming Anthropic responses""" - response = InstrumentedAnthropicStreamingResposne( + response = InstrumentedAnthropicStreamingResponse( context=context, client=client, anthropic_request=anthropic_request, diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 59d0874..b256a83 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -5,16 +5,27 @@ import json from typing import Any, Literal, Optional import httpx -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response from fastapi.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import ( + GatewayConfig, + GatewayConfigManager, + GuardrailsInHeader, +) from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from common.authorization import extract_authorization_from_headers -from common.request_context_data import RequestContextData +from common.guardrails import GuardrailAction, GuardrailRuleSet +from common.request_context import RequestContext from converters.gemini_to_invariant import convert_request, convert_response +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -23,8 +34,6 @@ from integrations.guardrails import ( preload_guardrails, check_guardrails, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from 
integrations.guardrails import check_guardrails, preload_guardrails gateway = APIRouter() @@ -43,6 +52,7 @@ async def gemini_generate_content_gateway( None, title="Response Format", description="Set to 'sse' for streaming" ), config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument + header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), ) -> Response: """Proxy calls to the Gemini GenerateContent API""" if endpoint not in ["generateContent", "streamGenerateContent"]: @@ -76,14 +86,19 @@ async def gemini_generate_content_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + guardrails=header_guardrails or dataset_guardrails, config=config, ) - asyncio.create_task(preload_guardrails(context)) - if alt == "sse" or endpoint == "streamGenerateContent": return await stream_response( context, @@ -98,16 +113,18 @@ async def gemini_generate_content_gateway( class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): + """Instrumented streaming response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -124,6 +141,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): location: Literal["request", "response"], guardrails_execution_result: dict[str, Any], ) -> dict: + """Create a refusal response for the given request or response""" 
return { "candidates": [ { @@ -157,10 +175,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): } async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -184,6 +205,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): ) async def event_generator(self): + """Event generator for streaming responses""" response = await self.client.send(self.gemini_request, stream=True) if response.status_code != 200: @@ -199,6 +221,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): yield chunk async def on_chunk(self, chunk): + """Processes each chunk of the streaming response""" chunk_text = chunk.decode().strip() if not chunk_text: return @@ -210,12 +233,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): if ( self.merged_response.get("candidates", []) and self.merged_response.get("candidates")[0].get("finishReason", "") - and self.context.config - and self.context.config.guardrails + and self.context.guardrails ): # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -254,7 +278,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): async def stream_response( - context: 
RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: @@ -269,7 +293,6 @@ async def stream_response( async def event_generator(): async for chunk in response.instrumented_event_generator(): yield chunk - print("chunk", chunk) return StreamingResponse( event_generator(), @@ -332,7 +355,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) -> def create_metadata( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, response_json: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -352,32 +375,53 @@ def create_metadata( async def get_guardrails_check_result( - context: RequestContextData, response_json: dict[str, Any] + context: RequestContext, action: GuardrailAction, response_json: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" + # Determine which guardrails to apply based on the action + guardrails = ( + context.guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.guardrails.blocking_guardrails + ) + if not guardrails: + return {} + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=converted_requests + converted_responses, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result async def push_to_explorer( - context: RequestContextData, + context: RequestContext, response_json: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []) + 
guardrails_execution_result.get("errors", []), action="block" ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=response_json, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []), action="log" + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + converted_requests = convert_request(context.request_json) converted_responses = convert_response(response_json) @@ -391,16 +435,18 @@ async def push_to_explorer( class InstrumentedGeminiResponse(InstrumentedResponse): + """Instrumented response for Gemini API""" + def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ): super().__init__() # request data - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.gemini_request: httpx.Request = gemini_request @@ -412,10 +458,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse): self.guardrails_execution_result: Optional[dict[str, Any]] = None async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). 
+ """ + if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, {} + self.context, action=GuardrailAction.BLOCK, response_json={} ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -463,6 +512,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def request(self): + """Makes the request to the Gemini API and return the response""" self.response = await self.client.send(self.gemini_request) response_string = self.response.text @@ -489,13 +539,16 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def on_end(self): + """Runs when the request ends.""" response_string = json.dumps(self.response_json) response_code = self.response.status_code - if self.context.config and self.context.config.guardrails: + if self.context.guardrails: # Block on the guardrails check guardrails_execution_result = await get_guardrails_check_result( - self.context, self.response_json + self.context, + action=GuardrailAction.BLOCK, + response_json=self.response_json, ) if guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -539,7 +592,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): async def handle_non_streaming_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, gemini_request: httpx.Request, ) -> Response: diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 6ef3808..45d0b14 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -5,14 +5,26 @@ import json from typing import Any, Optional import httpx -from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi.responses import StreamingResponse + +from common.authorization import extract_authorization_from_headers +from common.config_manager import ( + GatewayConfig, + GatewayConfigManager, + 
GuardrailsInHeader, +) from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from common.guardrails import GuardrailAction, GuardrailRuleSet +from common.request_context import RequestContext +from integrations.explorer import ( + create_annotations_from_guardrails_errors, + fetch_guardrails_from_explorer, + push_trace, +) from integrations.guardrails import ( ExtraItem, InstrumentedResponse, @@ -20,8 +32,6 @@ from integrations.guardrails import ( check_guardrails, preload_guardrails, ) -from common.authorization import extract_authorization_from_headers -from common.request_context_data import RequestContextData gateway = APIRouter() @@ -48,6 +58,7 @@ async def openai_chat_completions_gateway( request: Request, dataset_name: str = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument + header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), ) -> Response: """Proxy calls to the OpenAI APIs""" headers = { @@ -71,14 +82,19 @@ async def openai_chat_completions_gateway( headers=headers, ) - context = RequestContextData( + dataset_guardrails = None + if dataset_name: + # Get the guardrails for the dataset + dataset_guardrails = await fetch_guardrails_from_explorer( + dataset_name, invariant_authorization + ) + context = RequestContext.create( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + guardrails=header_guardrails or dataset_guardrails, config=config, ) - asyncio.create_task(preload_guardrails(context)) - if request_json.get("stream", False): return await handle_stream_response( context, @@ -91,19 +107,20 @@ async def openai_chat_completions_gateway( class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): """ - Does a streaming OpenAI completion request at the core, but 
also checks guardrails before (concurrent) and after the request. + Does a streaming OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request @@ -130,10 +147,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.tool_call_mapping_by_index = {} async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" - if self.context.config and self.context.config.guardrails: + """ + Check guardrails in a pipelined fashion, before processing the first chunk + (for input guardrailing). + """ + if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -163,6 +185,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def on_chunk(self, chunk): + """Processes each chunk of the stream and checks guardrails at the end of the stream""" # process and check each chunk chunk_text = chunk.decode().strip() if not chunk_text: @@ -178,14 +201,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) 
- if ( - "data: [DONE]" in chunk_text - and self.context.config - and self.context.config.guardrails - ): + if "data: [DONE]" in chunk_text and self.context.guardrails: # Block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.merged_response + self.context, + action=GuardrailAction.BLOCK, + response_json=self.merged_response, ) if self.guardrails_execution_result.get("errors", []): error_chunk = json.dumps( @@ -203,7 +224,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): # push will happen in on_end async def on_end(self): - """Sends full merged response to the exploree.""" + """Sends full merged response to the explorer.""" # don't block on the response from explorer (.create_task) if self.context.dataset_name: asyncio.create_task( @@ -213,10 +234,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): ) async def event_generator(self): - """ - Actual OpenAI stream response. - """ - + """Actual OpenAI stream response.""" response = await self.client.send(self.open_ai_request, stream=True) if response.status_code != 200: error_content = await response.aread() @@ -233,7 +251,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): async def handle_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: @@ -388,7 +406,7 @@ def update_existing_choice_with_delta( def create_metadata( - context: RequestContextData, merged_response: dict[str, Any] + context: RequestContext, merged_response: dict[str, Any] ) -> dict[str, Any]: """Creates metadata for the trace""" metadata = { @@ -408,7 +426,7 @@ def create_metadata( async def push_to_explorer( - context: RequestContextData, + context: RequestContext, merged_response: dict[str, Any], guardrails_execution_result: Optional[dict] = None, ) -> None: @@ -417,12 +435,26 @@ async def push_to_explorer( # or if 
the guardrails check returned errors. guardrails_execution_result = guardrails_execution_result or {} guardrails_errors = guardrails_execution_result.get("errors", []) - if guardrails_errors or not ( + annotations = create_annotations_from_guardrails_errors( + guardrails_errors, action="block" + ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []), action="log" + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + + if annotations or not ( merged_response.get("choices") and merged_response["choices"][0].get("finish_reason") not in FINISH_REASON_TO_PUSH_TRACE ): - annotations = create_annotations_from_guardrails_errors(guardrails_errors) # Combine the messages from the request body and the choices from the OpenAI response messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] @@ -436,18 +468,29 @@ async def push_to_explorer( async def get_guardrails_check_result( - context: RequestContextData, json_response: dict[str, Any] | None = None + context: RequestContext, + action: GuardrailAction, + response_json: dict[str, Any] | None = None, ) -> dict[str, Any]: """Get the guardrails check result""" - messages = list(context.request_json.get("messages", [])) + # Determine which guardrails to apply based on the action + guardrails = ( + context.guardrails.logging_guardrails + if action == GuardrailAction.LOG + else context.guardrails.blocking_guardrails + ) - if json_response is not None: - messages += [choice["message"] for choice in json_response.get("choices", [])] + if not guardrails: + return {} + + messages = list(context.request_json.get("messages", [])) 
+ if response_json is not None: + messages += [choice["message"] for choice in response_json.get("choices", [])] # Block on the guardrails check guardrails_execution_result = await check_guardrails( messages=messages, - guardrails=context.config.guardrails, + guardrails=guardrails, invariant_authorization=context.invariant_authorization, ) return guardrails_execution_result @@ -455,35 +498,39 @@ async def get_guardrails_check_result( class InstrumentedOpenAIResponse(InstrumentedResponse): """ - Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request. + Does an OpenAI completion request at the core, but also checks guardrails + before (concurrent) and after the request. """ def __init__( self, - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ): super().__init__() # request parameters - self.context: RequestContextData = context + self.context: RequestContext = context self.client: httpx.AsyncClient = client self.open_ai_request: httpx.Request = open_ai_request # request outputs self.response: Optional[httpx.Response] = None - self.json_response: Optional[dict[str, Any]] = None + self.response_json: Optional[dict[str, Any]] = None # guardrailing output (if any) self.guardrails_execution_result: Optional[dict] = None async def on_start(self): - """Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)""" - if self.context.config and self.context.config.guardrails: + """ + Checks guardrails in a pipelined fashion, before processing + the first chunk (for input guardrailing) + """ + if self.context.guardrails: # block on the guardrails check self.guardrails_execution_result = await get_guardrails_check_result( - self.context + self.context, action=GuardrailAction.BLOCK ) if self.guardrails_execution_result.get("errors", []): # Push annotated trace to the explorer - don't block on its response @@ 
-516,7 +563,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.response = await self.client.send(self.open_ai_request) try: - self.json_response = self.response.json() + self.response_json = self.response.json() except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, @@ -525,10 +572,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): if self.response.status_code != 200: raise HTTPException( status_code=self.response.status_code, - detail=self.json_response.get("error", "Unknown error from OpenAI API"), + detail=self.response_json.get("error", "Unknown error from OpenAI API"), ) - response_string = json.dumps(self.json_response) + response_string = json.dumps(self.response_json) response_code = self.response.status_code return Response( @@ -541,23 +588,26 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def on_end(self): """Postprocesses the OpenAI response and potentially replace it with a guardrails error.""" - # these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed) + # these two request outputs are guaranteed to be available by the time we reach + # this point (after self.request() was executed) # nevertheless, we check for them to avoid any potential issues assert ( self.response is not None ), "on_end called before 'self.response' was available" assert ( - self.json_response is not None - ), "on_end called before 'self.json_response' was available" + self.response_json is not None + ), "on_end called before 'self.response_json' was available" # extract original response status code response_code = self.response.status_code # if we have guardrails, check the response - if self.context.config and self.context.config.guardrails: + if self.context.guardrails: # run guardrails again, this time on request + response self.guardrails_execution_result = await get_guardrails_check_result( - self.context, self.json_response 
+ self.context, + action=GuardrailAction.BLOCK, + response_json=self.response_json, ) if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( @@ -573,7 +623,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, self.guardrails_execution_result, ) ) @@ -592,7 +642,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): asyncio.create_task( push_to_explorer( self.context, - self.json_response, + self.response_json, # include any guardrailing errors if available self.guardrails_execution_result, ) @@ -600,7 +650,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): async def handle_non_stream_response( - context: RequestContextData, + context: RequestContext, client: httpx.AsyncClient, open_ai_request: httpx.Request, ) -> Response: diff --git a/run.sh b/run.sh index f666f96..251f68d 100755 --- a/run.sh +++ b/run.sh @@ -93,7 +93,12 @@ integration_tests() { fi echo "File successfully downloaded: $FILE" - TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py" + if [[ -z "$INVARIANT_API_KEY" ]]; then + echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests." 
+ exit 1 + fi + + TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/integration_test_guardrails_via_file.py" if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then TEST_GUARDRAILS_FILE_PATH=$(realpath "$TEST_GUARDRAILS_FILE_PATH") diff --git a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py index a2e3db0..1597846 100644 --- a/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/integration/anthropic/test_anthropic_header_with_invariant_key.py @@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header( """Test the Anthropic gateway with Invariant key in the Anthropic key""" anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - { - "ANTHROPIC_API_KEY": anthropic_api_key - + ";invariant-auth=" - }, + {"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix}, ): client = anthropic.Anthropic( http_client=Client(), diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 89591ea..abc3937 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -12,10 +12,11 @@ from typing import Dict, List # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_anthropic_client + import anthropic import pytest import requests -from httpx import Client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -26,14 +27,8 @@ class WeatherAgent: def __init__(self, gateway_url, push_to_explorer): self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - 
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") - self.client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{self.dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + self.client = get_anthropic_client( + gateway_url, push_to_explorer, self.dataset_name ) self.get_weather_function = { "name": "get_weather", diff --git a/tests/integration/anthropic/test_anthropic_without_tool_call.py b/tests/integration/anthropic/test_anthropic_without_tool_call.py index a397c25..280341f 100644 --- a/tests/integration/anthropic/test_anthropic_without_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_without_tool_call.py @@ -8,10 +8,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -import anthropic +from utils import get_anthropic_client + import pytest import requests -from httpx import Client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -26,15 +26,10 @@ async def test_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") - - client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) cities = ["zurich", "new york", "london"] @@ -91,16 +86,7 @@ async def test_streaming_response_without_tool_call( ): """Test the Anthropic gateway without tool calling.""" dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - invariant_api_key = 
os.environ.get("INVARIANT_API_KEY", "None") - - client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", - ) + client = get_anthropic_client(gateway_url, push_to_explorer, dataset_name) cities = ["zurich", "new york", "london"] queries = [ diff --git a/tests/integration/docker-compose.test.yml b/tests/integration/docker-compose.test.yml index 52d481e..8c1afe3 100644 --- a/tests/integration/docker-compose.test.yml +++ b/tests/integration/docker-compose.test.yml @@ -60,6 +60,7 @@ services: app-api: container_name: invariant-gateway-test-explorer-app-api image: ghcr.io/invariantlabs-ai/explorer/app-api:latest + pull_policy: always platform: linux/amd64 depends_on: database: diff --git a/tests/integration/gemini/test_generate_content_with_tool_calls.py b/tests/integration/gemini/test_generate_content_with_tool_calls.py index bcf838e..96c4bd7 100644 --- a/tests/integration/gemini/test_generate_content_with_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_with_tool_calls.py @@ -8,9 +8,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client + import pytest import requests -from google import genai from google.genai import types # Pytest plugins @@ -143,18 +144,7 @@ async def test_generate_content_with_tool_call( without streaming. 
""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "invariant-authorization": "Bearer " - }, # This key is not used for local tests - }, - ) + client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gemini-2.0-flash", diff --git a/tests/integration/gemini/test_generate_content_without_tool_calls.py b/tests/integration/gemini/test_generate_content_without_tool_calls.py index d16797a..27036a2 100644 --- a/tests/integration/gemini/test_generate_content_without_tool_calls.py +++ b/tests/integration/gemini/test_generate_content_without_tool_calls.py @@ -10,6 +10,8 @@ from unittest.mock import patch # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client + import pytest import PIL.Image import requests @@ -29,17 +31,8 @@ async def test_generate_content( ): """Test the generate content gateway calls without tool calling.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "invariant-authorization": "Bearer " - }, # This key is not used for local tests - }, - ) + client = get_gemini_client(gateway_url, push_to_explorer, dataset_name) + request = { "model": "gemini-2.0-flash", "contents": "What is the capital of France?", @@ -115,18 +108,8 @@ async def test_generate_content_with_image( ): """Test that generate content gateway calls work with image.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client(gateway_url, push_to_explorer, 
dataset_name) - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - "headers": { - "invariant-authorization": "Bearer " - }, # This key is not used for local tests - }, - ) image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png" image = PIL.Image.open(image_path) @@ -181,9 +164,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( """Test the generate content gateway calls with the Invariant API Key in the Gemini Key header.""" dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" gemini_api_key = os.getenv("GEMINI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"GEMINI_API_KEY": gemini_api_key + ";invariant-auth="}, + {"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix}, ): client = genai.Client( api_key=os.getenv("GEMINI_API_KEY"), @@ -194,14 +178,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( chat_response = client.models.generate_content( model="gemini-2.0-flash", - contents="What is the capital of Spain?", + contents="What is the capital of Denmark?", config={ "maxOutputTokens": 100, }, ) # Verify the chat response - assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper() + assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper() expected_assistant_message = chat_response.candidates[0].content.parts[0].text # Wait for the trace to be saved @@ -228,7 +212,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header( assert trace["messages"] == [ { "role": "user", - "content": [{"text": "What is the capital of Spain?", "type": "text"}], + "content": [{"text": "What is the capital of Denmark?", "type": "text"}], }, { "role": "assistant", diff --git 
a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index 173e5ab..035a845 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -8,10 +8,11 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset + import pytest import requests -from httpx import Client -from anthropic import Anthropic, APIStatusError, BadRequestError +from anthropic import APIStatusError, BadRequestError # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -32,16 +33,10 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -161,16 +156,10 @@ async def test_tool_call_guardrail_from_file( } dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) if not do_stream: @@ -255,16 +244,10 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" - - client = 
Anthropic( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/anthropic", + client = get_anthropic_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -332,3 +315,276 @@ async def test_input_from_guardrail_from_file( == "Users must not mention the magic phrase 'Fight Club'" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + client = get_anthropic_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the 
explorer guardrails + spain_request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [{"role": "user", "content": "What is the capital of Spain?"}], + "max_tokens": 100, + } + if not do_stream: + chat_response = client.messages.create( + **spain_request, + stream=False, + ) + + assert "Madrid" in chat_response.content[0].text + else: + chat_response = client.messages.create( + **spain_request, + stream=True, + ) + + merged_content = "" + for chunk in chat_response: + if chunk.type == "content_block_delta": + merged_content += chunk.delta.text + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." + shrek_request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.messages.create( + **shrek_request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create( + **shrek_request, + stream=True, + ) + + for _ in chat_response: + pass + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + + # Wait for the trace to be saved + # This is needed because the trace is 
saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + client = get_anthropic_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + 
dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.messages.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + + else: + if do_stream: + _ = client.messages.create( + **request, + stream=True, + ) + else: + _ = client.messages.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == (2 if not is_block_action else 1) + assert
trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == ( + "block" if is_block_action else "log" + ) + ) diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index c463186..b3ac35e 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -8,9 +8,10 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset + import pytest import requests -from httpx import Client from google import genai # Pytest plugins @@ -30,17 +31,10 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -141,17 +135,10 @@ async def test_tool_call_guardrail_from_file( ) dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( -
api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -244,17 +231,10 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" - - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), - http_options={ - "headers": { - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/gemini", - }, + client = get_gemini_client( + gateway_url, + push_to_explorer, + dataset_name, ) request = { @@ -323,6 +303,259 @@ async def test_input_from_guardrail_from_file( ) +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = 
await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the explorer guardrails + spain_request = { + "model": "gemini-2.0-flash", + "contents": "What is the capital of Spain?", + "config": { + "maxOutputTokens": 100, + }, + } + if not do_stream: + chat_response = client.models.generate_content(**spain_request) + + assert "Madrid" in chat_response.candidates[0].content.parts[0].text + else: + chat_response = client.models.generate_content_stream(**spain_request) + + merged_content = "" + for chunk in chat_response: + if ( + chunk.candidates + and chunk.candidates[0].content + and chunk.candidates[0].content.parts + ): + for text_part in chunk.candidates[0].content.parts: + merged_content += text_part.text + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." 
+ shrek_request = { + "model": "gemini-2.0-flash", + "contents": user_prompt, + "config": { + "maxOutputTokens": 100, + }, + } + if not do_stream: + with pytest.raises(genai.errors.ClientError) as exc_info: + client.models.generate_content(**shrek_request) + + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + response = client.models.generate_content_stream(**shrek_request) + + assert_is_streamed_refusal( + response, + [ + "[Invariant] The response did not pass the guardrails", + "ogre detected in response", + ], + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt, + } + ], + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona 
detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) + + +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." 
+ request = { + "model": "gemini-2.0-flash", + "contents": user_prompt, + "config": { + "maxOutputTokens": 100, + }, + } + if is_block_action: + if do_stream: + chat_response = client.models.generate_content_stream(**request) + + assert_is_streamed_refusal( + chat_response, + [ + "[Invariant] The request did not pass the guardrails", + "pun detected in user message", + ], + ) + else: + with pytest.raises(genai.errors.ClientError) as exc_info: + chat_response = client.models.generate_content(**request) + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + response = client.models.generate_content_stream(**request) + for _ in response: + pass + else: + _ = client.models.generate_content(**request) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == (2 if not is_block_action else 1) + assert trace["messages"][0] == { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt, + } + ], + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
+ and annotations[0]["extra_metadata"]["guardrail-action"] == ( + "block" if is_block_action else "log" + ) + ) + + def is_refusal(chunk): return ( len(chunk.candidates) == 1 diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index acc2f67..6031778 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -8,10 +8,11 @@ import time # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset + import pytest import requests -from httpx import Client -from openai import OpenAI, BadRequestError, APIError +from openai import BadRequestError, APIError # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -30,17 +31,7 @@ async def test_message_content_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gpt-4o", @@ -161,17 +152,7 @@ async def test_tool_call_guardrail_from_file( } dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) if not do_stream: with pytest.raises(BadRequestError) as exc_info: @@
-259,17 +240,7 @@ async def test_input_from_guardrail_from_file( pytest.fail("No INVARIANT_API_KEY set, failing") dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" - }, - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) request = { "model": "gpt-4o", @@ -349,3 +320,268 @@ async def test_input_from_guardrail_from_file( == "Users must not mention the magic phrase 'Fight Club'" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize("do_stream", [True, False]) +async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content', + action="log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Ask about the 
capital of Spain + # This should not be blocked by the guardrails from the explorer when we push to explorer + # because the file based guardrails are overridden by the explorer guardrails + spain_request = { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "What is the capital of Spain?"}], + "max_tokens": 100, + } + if not do_stream: + chat_response = client.chat.completions.create( + **spain_request, + stream=False, + ) + + assert "Madrid" in chat_response.choices[0].message.content + else: + chat_response = client.chat.completions.create( + **spain_request, + stream=True, + ) + + merged_content = "" + for chunk in chat_response: + if chunk.choices[0].delta.content: + merged_content += chunk.choices[0].delta.content + assert "Madrid" in merged_content + + # Ask about Shrek + # This should be blocked by the guardrails from the explorer + user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details." 
+ shrek_request = { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **shrek_request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + # Only the block guardrail should be triggered here + assert "ogre detected in response" in str(exc_info.value) + assert "Fiona detected in response" not in str(exc_info.value) + else: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **shrek_request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The response did not pass the guardrails" in str( + exc_info.value + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 2 + trace_id = traces[1]["id"] + + # Fetch the second trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 2 + assert ( + annotations[0]["content"] == "ogre detected in response" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and 
annotations[0]["extra_metadata"]["guardrail-action"] == "block" + ) + assert ( + annotations[1]["content"] == "Fiona detected in response" + and annotations[1]["extra_metadata"]["source"] == "guardrails-error" + and annotations[1]["extra_metadata"]["guardrail-action"] == "log" + ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." 
+ request = { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + _ = client.chat.completions.create( + **request, + stream=True, + ) + else: + _ = client.chat.completions.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == (1 if is_block_action else 2) + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and
annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == ( + "block" if is_block_action else "log" + ) + ) diff --git a/tests/integration/guardrails/test_header_guardrails.py b/tests/integration/guardrails/test_header_guardrails.py index d5d6231..e03847d 100644 --- a/tests/integration/guardrails/test_header_guardrails.py +++ b/tests/integration/guardrails/test_header_guardrails.py @@ -1,4 +1,4 @@ -"""Test the guardrails from file with the OpenAI route.""" +"""Test the guardrails from header with the OpenAI route.""" import os import sys @@ -136,9 +136,7 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if: "do_stream, push_to_explorer", [(True, True), (True, False), (False, True), (False, False)], ) -async def test_invalid_guardrail_in_header( - explorer_api_url, gateway_url, do_stream, push_to_explorer -): +async def test_invalid_guardrail_in_header(gateway_url, do_stream, push_to_explorer): """Test the message content guardrail.""" if not os.getenv("INVARIANT_API_KEY"): pytest.fail("No INVARIANT_API_KEY set, failing") @@ -178,7 +176,8 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if: stream=False, ) - assert "Gateway: Guardrails check failed" in str( + print(exc_info.value.message, flush=True) + assert "Failed to create policy from policy source."
in str( exc_info.value ), "guardrails check fails because of an invalid guardrailing rule" assert "illegal statement" in str( diff --git a/tests/integration/open_ai/test_chat_with_tool_call.py b/tests/integration/open_ai/test_chat_with_tool_call.py index 69ad6af..ba928c6 100644 --- a/tests/integration/open_ai/test_chat_with_tool_call.py +++ b/tests/integration/open_ai/test_chat_with_tool_call.py @@ -9,10 +9,10 @@ import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client + import pytest import requests -from httpx import Client -from openai import OpenAI # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -28,17 +28,7 @@ async def test_chat_completion_with_tool_call_without_streaming( without streaming. """ dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": "Bearer " - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", @@ -146,17 +136,7 @@ async def test_chat_completion_with_tool_call_with_streaming( while streaming. 
""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": "Bearer " - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", diff --git a/tests/integration/open_ai/test_chat_without_tool_calls.py b/tests/integration/open_ai/test_chat_without_tool_calls.py index e9134b3..69d7816 100644 --- a/tests/integration/open_ai/test_chat_without_tool_calls.py +++ b/tests/integration/open_ai/test_chat_without_tool_calls.py @@ -11,6 +11,8 @@ from unittest.mock import patch # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils import get_open_ai_client + import pytest import requests from httpx import Client @@ -30,17 +32,7 @@ async def test_chat_completion( ): """Test the chat completions gateway calls without tool calling.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" - - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": "Bearer " - }, # This key is not used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) chat_response = client.chat.completions.create( model="gpt-4o", @@ -103,17 +95,8 @@ async def test_chat_completion_with_image( ): """Test the chat completions gateway works with image.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name) - client = OpenAI( - http_client=Client( - headers={ - "Invariant-Authorization": "Bearer " - }, # This key is not 
used for local tests - ), - base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway/openai", - ) image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png" with image_path.open("rb") as image_file: base64_image = base64.b64encode(image_file.read()).decode("utf-8") @@ -189,9 +172,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header( """Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header.""" dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" openai_api_key = os.getenv("OPENAI_API_KEY") + invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}" with patch.dict( os.environ, - {"OPENAI_API_KEY": openai_api_key + ";invariant-auth="}, + {"OPENAI_API_KEY": openai_api_key + invariant_key_suffix}, ): client = OpenAI( http_client=Client(), @@ -252,7 +236,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream): client = OpenAI( http_client=Client( headers={ - "Invariant-Authorization": "Bearer " + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" }, # This key is not used for local tests ), base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai", diff --git a/tests/integration/resources/guardrails/find_capital_guardrails.py b/tests/integration/resources/guardrails/integration_test_guardrails_via_file.py similarity index 87% rename from tests/integration/resources/guardrails/find_capital_guardrails.py rename to tests/integration/resources/guardrails/integration_test_guardrails_via_file.py index 282720f..fe31990 100644 --- a/tests/integration/resources/guardrails/find_capital_guardrails.py +++ b/tests/integration/resources/guardrails/integration_test_guardrails_via_file.py @@ -1,15 +1,15 @@ -# For Anthropic and OpenAI tests: +# For Anthropic and OpenAI tests raise "Madrid detected in the response" if: (msg: Message) "Madrid" in msg.content -# For Gemini 
tests:
+# For Gemini tests
 raise "Dublin detected in the response" if:
     (msg: Message)
     (chunk: str) in text(msg.content)
     "Dublin" in chunk
 
-# For all provider tests:
+# For all provider tests
 raise "get_capital is called with Germany as argument" if:
     (call: ToolCall)
     call is tool:get_capital
diff --git a/tests/integration/utils.py b/tests/integration/utils.py
new file mode 100644
index 0000000..6df9ce9
--- /dev/null
+++ b/tests/integration/utils.py
@@ -0,0 +1,140 @@
+"""Common utilities for integration tests."""
+
+import os
+import uuid
+from typing import Any, Dict, Literal, Optional
+
+from httpx import AsyncClient, Client
+from openai import OpenAI
+from google import genai
+from anthropic import Anthropic
+
+
+def _invariant_headers() -> Dict[str, str]:
+    """Build the Invariant authorization header from INVARIANT_API_KEY."""
+    return {"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"}
+
+
+def _gateway_base_url(
+    gateway_url: str, provider: str, push_to_explorer: bool, dataset_name: str
+) -> str:
+    """Return the provider route, scoped to the dataset when pushing to Explorer."""
+    if push_to_explorer:
+        return f"{gateway_url}/api/v1/gateway/{dataset_name}/{provider}"
+    return f"{gateway_url}/api/v1/gateway/{provider}"
+
+
+def get_open_ai_client(
+    gateway_url: str, push_to_explorer: bool, dataset_name: str
+) -> OpenAI:
+    """Create an OpenAI client routed through the gateway for integration tests."""
+    return OpenAI(
+        http_client=Client(headers=_invariant_headers()),
+        base_url=_gateway_base_url(
+            gateway_url, "openai", push_to_explorer, dataset_name
+        ),
+    )
+
+
+def get_anthropic_client(
+    gateway_url: str, push_to_explorer: bool, dataset_name: str
+) -> Anthropic:
+    """Create an Anthropic client routed through the gateway for integration tests."""
+    return Anthropic(
+        http_client=Client(headers=_invariant_headers()),
+        base_url=_gateway_base_url(
+            gateway_url, "anthropic", push_to_explorer, dataset_name
+        ),
+    )
+
+
+def get_gemini_client(
+    gateway_url: str, push_to_explorer: bool, dataset_name: str
+) -> genai.Client:
+    """Create a Gemini client routed through the gateway for integration tests."""
+    return genai.Client(
+        api_key=os.getenv("GEMINI_API_KEY"),
+        http_options={
+            "base_url": _gateway_base_url(
+                gateway_url, "gemini", push_to_explorer, dataset_name
+            ),
+            "headers": _invariant_headers(),
+        },
+    )
+
+
+async def create_dataset(
+    explorer_api_url: str,
+    invariant_authorization: str,
+    dataset_name: Optional[str] = None,
+) -> Dict[str, Any]:
+    """Create a dataset in the Explorer API.
+
+    Args:
+        explorer_api_url: Base URL of the Explorer API.
+        invariant_authorization: Value sent as the Authorization header.
+        dataset_name: Dataset name; a random test name is generated when empty.
+
+    Returns:
+        The JSON body of the Explorer response.
+
+    Raises:
+        ValueError: If the Explorer API does not answer with status 200.
+    """
+    name = dataset_name if dataset_name else f"test-dataset-{uuid.uuid4()}"
+    # Use the async client so the request does not block the event loop, and
+    # close it deterministically once the request is done.
+    async with AsyncClient(base_url=explorer_api_url) as client:
+        response = await client.post(
+            "/api/v1/dataset/create",
+            json={"name": name},
+            headers={"Authorization": invariant_authorization},
+            timeout=5,
+        )
+    if response.status_code != 200:
+        raise ValueError(
+            f"Failed to create dataset: {response.status_code}, {response.text}"
+        )
+    return response.json()
+
+
+async def add_guardrail_to_dataset(
+    explorer_api_url: str,
+    dataset_id: str,
+    policy: str,
+    action: Literal["block", "log"],
+    invariant_authorization: str,
+) -> Dict[str, Any]:
+    """Add a guardrail to a dataset.
+
+    Args:
+        explorer_api_url: Base URL of the Explorer API.
+        dataset_id: Dataset identifier placed in the request path.
+        policy: Value sent as the guardrail's "policy" field.
+        action: Value sent as the guardrail's "action" field.
+        invariant_authorization: Value sent as the Authorization header.
+
+    Returns:
+        The JSON body of the Explorer response.
+
+    Raises:
+        ValueError: If the Explorer API does not answer with status 200.
+    """
+    # Use the async client so the request does not block the event loop, and
+    # close it deterministically once the request is done.
+    async with AsyncClient(base_url=explorer_api_url) as client:
+        response = await client.post(
+            f"/api/v1/dataset/{dataset_id}/policy",
+            json={
+                "action": action,
+                "policy": policy,
+                "name": f"test-guardrail-{uuid.uuid4()}",
+            },
+            headers={"Authorization": invariant_authorization},
+            timeout=5,
+        )
+    if response.status_code != 200:
+        raise ValueError(
+            f"Failed to add guardrail: {response.status_code}, {response.text}"
+        )
+    return response.json()