mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 06:56:49 +02:00
Merge branch 'main' into guardrails-from-header
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
from openai import OpenAI
|
||||
from httpx import Client
|
||||
import os
|
||||
|
||||
# unicode escape everything
|
||||
guardrails = """
|
||||
raise "Rule 1: Do not talk about Fight Club" if:
|
||||
(msg: Message)
|
||||
"fight club" in msg.content
|
||||
""".encode("unicode_escape")
|
||||
|
||||
openai_client = OpenAI(
|
||||
default_headers={
|
||||
"Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
"Invariant-Guardrails": guardrails,
|
||||
},
|
||||
base_url="http://localhost:8000/api/v1/gateway/non-streaming/openai",
|
||||
)
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What can you tell me about fight club?",
|
||||
}
|
||||
],
|
||||
)
|
||||
print("Response: ", response.choices[0].message.content)
|
||||
+3
-3
@@ -1,4 +1,4 @@
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=invariantmonitor
|
||||
POSTGRES_HOST=database
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=invariantmonitor
|
||||
POSTGRES_HOST=database
|
||||
@@ -8,8 +8,10 @@ API_KEYS_SEPARATOR = ";invariant-auth="
|
||||
|
||||
|
||||
def extract_authorization_from_headers(
|
||||
request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str
|
||||
) -> Tuple[str, str]:
|
||||
request: Request,
|
||||
dataset_name: Optional[str] = None,
|
||||
llm_provider_api_key_header: Optional[str] = None,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extracts the Invariant authorization and LLM Provider API key from the request headers.
|
||||
|
||||
@@ -26,8 +28,15 @@ def extract_authorization_from_headers(
|
||||
The header in that case becomes:
|
||||
{llm_provider_api_key_header}: "<API Key>;invariant-auth=<Invariant API Key>"
|
||||
"""
|
||||
# invariant api key
|
||||
invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER)
|
||||
llm_provider_api_key = request.headers.get(llm_provider_api_key_header)
|
||||
# llm provider api key
|
||||
if llm_provider_api_key_header is not None:
|
||||
llm_provider_api_key = request.headers.get(llm_provider_api_key_header)
|
||||
else:
|
||||
llm_provider_api_key = None
|
||||
|
||||
# if the dataset name is not None, we need to check if the invariant api key is present
|
||||
if dataset_name:
|
||||
if invariant_authorization is None:
|
||||
if llm_provider_api_key is None:
|
||||
@@ -43,9 +52,7 @@ def extract_authorization_from_headers(
|
||||
API_KEYS_SEPARATOR
|
||||
)
|
||||
if len(api_keys) != 2 or not api_keys[1].strip():
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid API Key format"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid API Key format")
|
||||
|
||||
invariant_authorization = f"Bearer {api_keys[1].strip()}"
|
||||
llm_provider_api_key = f"{api_keys[0].strip()}"
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import Optional
|
||||
import fastapi
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
|
||||
|
||||
def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]:
|
||||
"""
|
||||
@@ -29,8 +32,8 @@ def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[
|
||||
class GatewayConfig:
|
||||
"""Common configurations for the Gateway Server."""
|
||||
|
||||
def __init__(self, guardrails: Optional[str] = None):
|
||||
self.guardrails = guardrails or self._load_guardrails_from_file()
|
||||
def __init__(self):
|
||||
self.guardrails = self._load_guardrails_from_file()
|
||||
|
||||
def _load_guardrails_from_file(self) -> str:
|
||||
"""
|
||||
@@ -67,13 +70,7 @@ class GatewayConfig:
|
||||
raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
|
||||
|
||||
def with_guardrails(self, guardrails: str) -> "GatewayConfig":
|
||||
"""
|
||||
Returns a new GatewayConfig instance with the specified guardrails.
|
||||
"""
|
||||
return GatewayConfig(guardrails)
|
||||
return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})"
|
||||
|
||||
|
||||
class GatewayConfigManager:
|
||||
@@ -94,8 +91,20 @@ class GatewayConfigManager:
|
||||
local_config = GatewayConfig()
|
||||
cls._config_instance = local_config
|
||||
|
||||
# if provided in header, use custom guardrailing policy
|
||||
if guardrail_file_contents := extract_policy_from_headers(request):
|
||||
local_config = local_config.with_guardrails(guardrail_file_contents)
|
||||
|
||||
return local_config
|
||||
|
||||
|
||||
async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]:
|
||||
# if provided in header, use custom guardrailing policy
|
||||
if guardrails := extract_policy_from_headers(request):
|
||||
return GuardrailRuleSet(
|
||||
blocking_guardrails=[
|
||||
Guardrail(
|
||||
id="guardrail-from-header",
|
||||
name="guardrails from request header",
|
||||
content=guardrails,
|
||||
action=GuardrailAction.BLOCK,
|
||||
)
|
||||
],
|
||||
logging_guardrails=[],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Common guardrails data class."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class GuardrailAction(str, Enum):
|
||||
"""Enum representing the action to be taken for guardrail rules."""
|
||||
|
||||
BLOCK = "block"
|
||||
LOG = "log"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Guardrail:
|
||||
"""Represents a single guardrail rule."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
content: str
|
||||
action: GuardrailAction
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GuardrailRuleSet:
|
||||
"""Grouped guardrail rules separated by their action."""
|
||||
|
||||
blocking_guardrails: List[Guardrail]
|
||||
logging_guardrails: List[Guardrail]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Common Request context data class."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from common.config_manager import GatewayConfig
|
||||
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContext:
|
||||
"""Structured context for a request. Must be created via `RequestContext.create()`."""
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
invariant_authorization: Optional[str] = None
|
||||
# the set of guardrails to enforce for this request
|
||||
guardrails: Optional[GuardrailRuleSet] = None
|
||||
config: Dict[str, Any] = None
|
||||
|
||||
_created_via_factory: bool = field(
|
||||
default=False, init=True, repr=False, compare=False
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self._created_via_factory:
|
||||
raise RuntimeError(
|
||||
"RequestContext must be created using RequestContext.create()"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
request_json: Dict[str, Any],
|
||||
dataset_name: Optional[str] = None,
|
||||
invariant_authorization: Optional[str] = None,
|
||||
guardrails: Optional[GuardrailRuleSet] = None,
|
||||
config: Optional[GatewayConfig] = None,
|
||||
) -> "RequestContext":
|
||||
"""Creates a new RequestContext instance, applying default guardrails if needed."""
|
||||
|
||||
# Convert GatewayConfig to a basic dict, excluding guardrails_from_file
|
||||
context_config = {
|
||||
key: value
|
||||
for key, value in (config.__dict__.items() if config else {})
|
||||
if key != "guardrails_from_file"
|
||||
}
|
||||
|
||||
# If no guardrails are configured for the dataset on Explorer,
|
||||
# and the config specifies guardrails_from_file, use that.
|
||||
guardrails = guardrails
|
||||
if (
|
||||
(
|
||||
not guardrails
|
||||
or (
|
||||
not guardrails.blocking_guardrails
|
||||
and not guardrails.logging_guardrails
|
||||
)
|
||||
)
|
||||
and config
|
||||
and config.guardrails_from_file
|
||||
):
|
||||
# TODO: Support logging guardrails via file.
|
||||
guardrails = GuardrailRuleSet(
|
||||
blocking_guardrails=[
|
||||
Guardrail(
|
||||
id="default",
|
||||
name="default",
|
||||
content=config.guardrails_from_file,
|
||||
action=GuardrailAction.BLOCK,
|
||||
)
|
||||
],
|
||||
logging_guardrails=[],
|
||||
)
|
||||
|
||||
return cls(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=guardrails,
|
||||
config=context_config,
|
||||
_created_via_factory=True,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"RequestContext("
|
||||
f"request_json={self.request_json}, "
|
||||
f"dataset_name={self.dataset_name}, "
|
||||
f"invariant_authorization={self.invariant_authorization}, "
|
||||
f"guardrails={self.guardrails}, "
|
||||
f"config={self.config})"
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Common Request context data class."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from common.config_manager import GatewayConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContextData:
|
||||
"""Request context data class."""
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
invariant_authorization: Optional[str] = None
|
||||
config: Optional[GatewayConfig] = None
|
||||
@@ -3,15 +3,18 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
from invariant_sdk.types.annotations import AnnotationCreate
|
||||
|
||||
import httpx
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def create_annotations_from_guardrails_errors(
|
||||
guardrails_errors: List[dict],
|
||||
guardrails_errors: List[dict], action: str = "block"
|
||||
) -> List[AnnotationCreate]:
|
||||
"""Create Explorer annotations from the guardrails errors."""
|
||||
annotations = []
|
||||
@@ -45,7 +48,10 @@ def create_annotations_from_guardrails_errors(
|
||||
AnnotationCreate(
|
||||
content=content,
|
||||
address=r,
|
||||
extra_metadata={"source": "guardrails-error"},
|
||||
extra_metadata={
|
||||
"source": "guardrails-error",
|
||||
"guardrail-action": action,
|
||||
},
|
||||
)
|
||||
)
|
||||
return annotations
|
||||
@@ -91,3 +97,79 @@ async def push_trace(
|
||||
except Exception as e:
|
||||
print(f"Failed to push trace: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def fetch_guardrails_from_explorer(
|
||||
dataset_name: str, invariant_authorization: str
|
||||
) -> GuardrailRuleSet:
|
||||
"""Get the guardrails for the dataset.
|
||||
|
||||
Returns:
|
||||
GuardrailRuleSet: The guardrails for the dataset grouped by their action.
|
||||
"""
|
||||
|
||||
# TODO: Implement a single API in explorer backend which can return
|
||||
# dataset details without requiring a username.
|
||||
|
||||
client = httpx.AsyncClient(
|
||||
base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
},
|
||||
)
|
||||
|
||||
# Get the user details.
|
||||
user_info_response = await client.get("/api/v1/user/info")
|
||||
if user_info_response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
|
||||
)
|
||||
user_details = user_info_response.json()
|
||||
username = user_details["username"]
|
||||
|
||||
# Get the dataset policies.
|
||||
policies_response = await client.get(
|
||||
f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy"
|
||||
)
|
||||
if policies_response.status_code != 200:
|
||||
if policies_response.status_code == 404:
|
||||
# If the dataset does not exist, return empty guardrails.
|
||||
return GuardrailRuleSet(
|
||||
blocking_guardrails=[],
|
||||
logging_guardrails=[],
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}"
|
||||
)
|
||||
policies_details = policies_response.json()
|
||||
guardrails = policies_details.get("policies", [])
|
||||
|
||||
blocking_guardrails = []
|
||||
logging_guardrails = []
|
||||
for g in guardrails:
|
||||
action = g["action"]
|
||||
|
||||
if not g["enabled"]:
|
||||
# Skip guardrails that are not enabled.
|
||||
continue
|
||||
|
||||
if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG):
|
||||
print("[Warning] Skipping unknown guardrail action: ", action)
|
||||
continue
|
||||
|
||||
guardrail = Guardrail(
|
||||
id=g["id"],
|
||||
name=g["name"],
|
||||
content=g["content"],
|
||||
action=GuardrailAction(action),
|
||||
)
|
||||
|
||||
if action == GuardrailAction.BLOCK:
|
||||
blocking_guardrails.append(guardrail)
|
||||
else:
|
||||
logging_guardrails.append(guardrail)
|
||||
|
||||
return GuardrailRuleSet(
|
||||
blocking_guardrails=blocking_guardrails,
|
||||
logging_guardrails=logging_guardrails,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,8 @@ from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
import httpx
|
||||
from common.request_context_data import RequestContextData
|
||||
from common.guardrails import Guardrail
|
||||
from common.request_context import RequestContext
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
@@ -81,21 +82,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None:
|
||||
result.raise_for_status()
|
||||
|
||||
|
||||
async def preload_guardrails(context: "RequestContextData") -> None:
|
||||
async def preload_guardrails(context: "RequestContext") -> None:
|
||||
"""
|
||||
Preloads the guardrails for faster checking later.
|
||||
|
||||
Args:
|
||||
context: RequestContextData object.
|
||||
context: RequestContext object.
|
||||
"""
|
||||
if not context.config or not context.config.guardrails:
|
||||
if not context.guardrails:
|
||||
return
|
||||
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
_preload(context.config.guardrails, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
# Move these calls to a batch preload/validate API.
|
||||
for blocking_guardrail in context.guardrails.blocking_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(blocking_guardrail.content, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
for logging_guadrail in context.guardrails.logging_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(logging_guadrail.content, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
except Exception as e:
|
||||
print(f"Error scheduling preload_guardrails task: {e}")
|
||||
|
||||
@@ -322,14 +330,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
|
||||
messages: List[Dict[str, Any]],
|
||||
guardrails: List[Guardrail],
|
||||
invariant_authorization: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks guardrails on the list of messages.
|
||||
This calls the batch check API of the Guardrails service.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
|
||||
guardrails (str): The guardrails to check against.
|
||||
guardrails (List[Guardrail]): The guardrails to check against.
|
||||
invariant_authorization (str): Value of the
|
||||
invariant-authorization header.
|
||||
|
||||
@@ -340,8 +351,11 @@ async def check_guardrails(
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check",
|
||||
json={"messages": messages, "policy": guardrails},
|
||||
f"{url}/api/v1/policy/check/batch",
|
||||
json={
|
||||
"messages": messages,
|
||||
"policies": [g.content for g in guardrails],
|
||||
},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
@@ -351,8 +365,20 @@ async def check_guardrails(
|
||||
raise Exception(
|
||||
f"Guardrails check failed: {result.status_code} - {result.text}"
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
guardrails_result = result.json()
|
||||
|
||||
aggregated_errors = {"errors": []}
|
||||
for res in guardrails_result.get("result", []):
|
||||
aggregated_errors["errors"].extend(res.get("errors", []))
|
||||
|
||||
# check for any error_message
|
||||
if error_message := res.get("error_message"):
|
||||
return {
|
||||
"errors": [
|
||||
{"args": [error_message], "kwargs": {}, "ranges": []}
|
||||
]
|
||||
}
|
||||
return aggregated_errors
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
# make sure runtime errors are also visible in e.g. Explorer
|
||||
|
||||
+92
-47
@@ -5,20 +5,29 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from regex import R
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
)
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
from common.request_context import RequestContext
|
||||
from converters.anthropic_to_invariant import (
|
||||
convert_anthropic_to_invariant_message_format,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -61,6 +70,7 @@ async def anthropic_v1_messages_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
):
|
||||
"""Proxy calls to the Anthropic APIs"""
|
||||
headers = {
|
||||
@@ -83,21 +93,26 @@ async def anthropic_v1_messages_gateway(
|
||||
data=request_body,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset from explorer.
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
|
||||
if request_json.get("stream"):
|
||||
return await handle_streaming_response(context, client, anthropic_request)
|
||||
return await handle_non_streaming_response(context, client, anthropic_request)
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {k: v for k, v in context.request_json.items() if k != "messages"}
|
||||
@@ -108,7 +123,7 @@ def create_metadata(
|
||||
|
||||
|
||||
def combine_request_and_response_messages(
|
||||
context: RequestContextData, json_response: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
):
|
||||
"""Combine the request and response messages"""
|
||||
messages = []
|
||||
@@ -117,42 +132,63 @@ def combine_request_and_response_messages(
|
||||
{"role": "system", "content": context.request_json.get("system")}
|
||||
)
|
||||
messages.extend(context.request_json.get("messages", []))
|
||||
if len(json_response) > 0:
|
||||
messages.append(json_response)
|
||||
if len(response_json) > 0:
|
||||
messages.append(response_json)
|
||||
return messages
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, json_response: dict[str, Any]
|
||||
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
messages = combine_request_and_response_messages(context, json_response)
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.guardrails.blocking_guardrails
|
||||
)
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
messages = combine_request_and_response_messages(context, response_json)
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_messages,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_execution_result.get("errors", [])
|
||||
guardrails_execution_result.get("errors", []), action="block"
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
# Combine the messages from the request body and Anthropic response
|
||||
messages = combine_request_and_response_messages(context, merged_response)
|
||||
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
|
||||
_ = await push_trace(
|
||||
dataset_name=context.dataset_name,
|
||||
messages=[converted_messages],
|
||||
@@ -163,30 +199,32 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
"""Instrumented response for Anthropic API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
# response data
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_string: Optional[str] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -220,10 +258,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def request(self):
|
||||
"""Make the request to the Anthropic API."""
|
||||
self.response = await self.client.send(self.anthropic_request)
|
||||
|
||||
try:
|
||||
json_response = self.response.json()
|
||||
response_json = self.response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
@@ -232,11 +271,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
detail=response_json.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
|
||||
self.json_response = json_response
|
||||
self.response_string = json.dumps(json_response)
|
||||
self.response_json = response_json
|
||||
self.response_string = json.dumps(response_json)
|
||||
|
||||
return self._make_response(
|
||||
content=self.response_string,
|
||||
@@ -261,13 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
|
||||
# ensure the response data is available
|
||||
assert self.response is not None, "response is None"
|
||||
assert self.json_response is not None, "json_response is None"
|
||||
assert self.response_json is not None, "response_json is None"
|
||||
assert self.response_string is not None, "response_string is None"
|
||||
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
guardrail_response_string = json.dumps(
|
||||
@@ -283,7 +324,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
@@ -300,13 +341,13 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context, self.json_response, guardrails_execution_result
|
||||
self.context, self.response_json, guardrails_execution_result
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -320,17 +361,19 @@ async def handle_non_streaming_response(
|
||||
return await response.instrumented_request()
|
||||
|
||||
|
||||
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
"""Instrumented streaming response for Anthropic API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
@@ -342,9 +385,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -392,6 +437,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Process the chunk and update the merged_response"""
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
return
|
||||
@@ -400,14 +446,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
process_chunk(decoded_chunk, self.merged_response)
|
||||
|
||||
# on last stream chunk, run output guardrails
|
||||
if (
|
||||
"event: message_stop" in decoded_chunk
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
):
|
||||
if "event: message_stop" in decoded_chunk and self.context.guardrails:
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -420,7 +464,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
}
|
||||
)
|
||||
|
||||
# yield an extra error chunk (without preventing the original chunk to go through after,
|
||||
# yield an extra error chunk (without preventing the original chunk
|
||||
# to go through after,
|
||||
# so client gets the proper message_stop event still)
|
||||
return ExtraItem(
|
||||
value=f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
@@ -440,12 +485,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
response = InstrumentedAnthropicStreamingResposne(
|
||||
response = InstrumentedAnthropicStreamingResponse(
|
||||
context=context,
|
||||
client=client,
|
||||
anthropic_request=anthropic_request,
|
||||
|
||||
+84
-31
@@ -5,16 +5,27 @@ import json
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
)
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
from common.request_context import RequestContext
|
||||
from converters.gemini_to_invariant import convert_request, convert_response
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -23,8 +34,6 @@ from integrations.guardrails import (
|
||||
preload_guardrails,
|
||||
check_guardrails,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardrails import check_guardrails, preload_guardrails
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -43,6 +52,7 @@ async def gemini_generate_content_gateway(
|
||||
None, title="Response Format", description="Set to 'sse' for streaming"
|
||||
),
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
) -> Response:
|
||||
"""Proxy calls to the Gemini GenerateContent API"""
|
||||
if endpoint not in ["generateContent", "streamGenerateContent"]:
|
||||
@@ -76,14 +86,19 @@ async def gemini_generate_content_gateway(
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
|
||||
if alt == "sse" or endpoint == "streamGenerateContent":
|
||||
return await stream_response(
|
||||
context,
|
||||
@@ -98,16 +113,18 @@ async def gemini_generate_content_gateway(
|
||||
|
||||
|
||||
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
"""Instrumented streaming response for Gemini API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request data
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.gemini_request: httpx.Request = gemini_request
|
||||
|
||||
@@ -124,6 +141,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
location: Literal["request", "response"],
|
||||
guardrails_execution_result: dict[str, Any],
|
||||
) -> dict:
|
||||
"""Create a refusal response for the given request or response"""
|
||||
return {
|
||||
"candidates": [
|
||||
{
|
||||
@@ -157,10 +175,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -184,6 +205,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def event_generator(self):
|
||||
"""Event generator for streaming responses"""
|
||||
response = await self.client.send(self.gemini_request, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -199,6 +221,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Processes each chunk of the streaming response"""
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
return
|
||||
@@ -210,12 +233,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
if (
|
||||
self.merged_response.get("candidates", [])
|
||||
and self.merged_response.get("candidates")[0].get("finishReason", "")
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
and self.context.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -254,7 +278,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -269,7 +293,6 @@ async def stream_response(
|
||||
async def event_generator():
|
||||
async for chunk in response.instrumented_event_generator():
|
||||
yield chunk
|
||||
print("chunk", chunk)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -332,7 +355,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) ->
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {
|
||||
@@ -352,32 +375,53 @@ def create_metadata(
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.guardrails.blocking_guardrails
|
||||
)
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
converted_requests = convert_request(context.request_json)
|
||||
converted_responses = convert_response(response_json)
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_requests + converted_responses,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
response_json: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_execution_result.get("errors", [])
|
||||
guardrails_execution_result.get("errors", []), action="block"
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=response_json,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
converted_requests = convert_request(context.request_json)
|
||||
converted_responses = convert_response(response_json)
|
||||
|
||||
@@ -391,16 +435,18 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
"""Instrumented response for Gemini API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request data
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.gemini_request: httpx.Request = gemini_request
|
||||
|
||||
@@ -412,10 +458,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
self.guardrails_execution_result: Optional[dict[str, Any]] = None
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -463,6 +512,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def request(self):
|
||||
"""Makes the request to the Gemini API and return the response"""
|
||||
self.response = await self.client.send(self.gemini_request)
|
||||
|
||||
response_string = self.response.text
|
||||
@@ -489,13 +539,16 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def on_end(self):
|
||||
"""Runs when the request ends."""
|
||||
response_string = json.dumps(self.response_json)
|
||||
response_code = self.response.status_code
|
||||
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.response_json
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
@@ -539,7 +592,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
) -> Response:
|
||||
|
||||
+102
-52
@@ -5,14 +5,26 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
)
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
from common.request_context import RequestContext
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -20,8 +32,6 @@ from integrations.guardrails import (
|
||||
check_guardrails,
|
||||
preload_guardrails,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -48,6 +58,7 @@ async def openai_chat_completions_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
) -> Response:
|
||||
"""Proxy calls to the OpenAI APIs"""
|
||||
headers = {
|
||||
@@ -71,14 +82,19 @@ async def openai_chat_completions_gateway(
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
guardrails=header_guardrails or dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
|
||||
if request_json.get("stream", False):
|
||||
return await handle_stream_response(
|
||||
context,
|
||||
@@ -91,19 +107,20 @@ async def openai_chat_completions_gateway(
|
||||
|
||||
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
"""
|
||||
Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
|
||||
Does a streaming OpenAI completion request at the core, but also checks guardrails
|
||||
before (concurrent) and after the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
@@ -130,10 +147,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
self.tool_call_mapping_by_index = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -163,6 +185,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Processes each chunk of the stream and checks guardrails at the end of the stream"""
|
||||
# process and check each chunk
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
@@ -178,14 +201,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
|
||||
if (
|
||||
"data: [DONE]" in chunk_text
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
):
|
||||
if "data: [DONE]" in chunk_text and self.context.guardrails:
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -203,7 +224,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
# push will happen in on_end
|
||||
|
||||
async def on_end(self):
|
||||
"""Sends full merged response to the exploree."""
|
||||
"""Sends full merged response to the explorer."""
|
||||
# don't block on the response from explorer (.create_task)
|
||||
if self.context.dataset_name:
|
||||
asyncio.create_task(
|
||||
@@ -213,10 +234,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def event_generator(self):
|
||||
"""
|
||||
Actual OpenAI stream response.
|
||||
"""
|
||||
|
||||
"""Actual OpenAI stream response."""
|
||||
response = await self.client.send(self.open_ai_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
@@ -233,7 +251,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def handle_stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -388,7 +406,7 @@ def update_existing_choice_with_delta(
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, merged_response: dict[str, Any]
|
||||
context: RequestContext, merged_response: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {
|
||||
@@ -408,7 +426,7 @@ def create_metadata(
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
@@ -417,12 +435,26 @@ async def push_to_explorer(
|
||||
# or if the guardrails check returned errors.
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
guardrails_errors = guardrails_execution_result.get("errors", [])
|
||||
if guardrails_errors or not (
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_errors, action="block"
|
||||
)
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
if annotations or not (
|
||||
merged_response.get("choices")
|
||||
and merged_response["choices"][0].get("finish_reason")
|
||||
not in FINISH_REASON_TO_PUSH_TRACE
|
||||
):
|
||||
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
@@ -436,18 +468,29 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, json_response: dict[str, Any] | None = None
|
||||
context: RequestContext,
|
||||
action: GuardrailAction,
|
||||
response_json: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.guardrails.blocking_guardrails
|
||||
)
|
||||
|
||||
if json_response is not None:
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
if response_json is not None:
|
||||
messages += [choice["message"] for choice in response_json.get("choices", [])]
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
@@ -455,35 +498,39 @@ async def get_guardrails_check_result(
|
||||
|
||||
class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
"""
|
||||
Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
|
||||
Does an OpenAI completion request at the core, but also checks guardrails
|
||||
before (concurrent) and after the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
# request outputs
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
|
||||
# guardrailing output (if any)
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
|
||||
async def on_start(self):
|
||||
"""Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Checks guardrails in a pipelined fashion, before processing
|
||||
the first chunk (for input guardrailing)
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
# block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context
|
||||
self.context, action=GuardrailAction.BLOCK
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
@@ -516,7 +563,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.response = await self.client.send(self.open_ai_request)
|
||||
|
||||
try:
|
||||
self.json_response = self.response.json()
|
||||
self.response_json = self.response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
@@ -525,10 +572,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
|
||||
detail=self.response_json.get("error", "Unknown error from OpenAI API"),
|
||||
)
|
||||
|
||||
response_string = json.dumps(self.json_response)
|
||||
response_string = json.dumps(self.response_json)
|
||||
response_code = self.response.status_code
|
||||
|
||||
return Response(
|
||||
@@ -541,23 +588,26 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
async def on_end(self):
|
||||
"""Postprocesses the OpenAI response and potentially replace it with a guardrails error."""
|
||||
|
||||
# these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
|
||||
# these two request outputs are guaranteed to be available by the time we reach
|
||||
# this point (after self.request() was executed)
|
||||
# nevertheless, we check for them to avoid any potential issues
|
||||
assert (
|
||||
self.response is not None
|
||||
), "on_end called before 'self.response' was available"
|
||||
assert (
|
||||
self.json_response is not None
|
||||
), "on_end called before 'self.json_response' was available"
|
||||
self.response_json is not None
|
||||
), "on_end called before 'self.response_json' was available"
|
||||
|
||||
# extract original response status code
|
||||
response_code = self.response.status_code
|
||||
|
||||
# if we have guardrails, check the response
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.guardrails:
|
||||
# run guardrails again, this time on request + response
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
@@ -573,7 +623,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
@@ -592,7 +642,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
# include any guardrailing errors if available
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
@@ -600,7 +650,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
|
||||
|
||||
async def handle_non_stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
|
||||
@@ -93,7 +93,12 @@ integration_tests() {
|
||||
fi
|
||||
echo "File successfully downloaded: $FILE"
|
||||
|
||||
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py"
|
||||
if [[ -z "$INVARIANT_API_KEY" ]]; then
|
||||
echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/integration_test_guardrails_via_file.py"
|
||||
if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then
|
||||
if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then
|
||||
TEST_GUARDRAILS_FILE_PATH=$(realpath "$TEST_GUARDRAILS_FILE_PATH")
|
||||
|
||||
@@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
|
||||
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"ANTHROPIC_API_KEY": anthropic_api_key
|
||||
+ ";invariant-auth=<not needed for test>"
|
||||
},
|
||||
{"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(),
|
||||
|
||||
@@ -12,10 +12,11 @@ from typing import Dict, List
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_anthropic_client
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -26,14 +27,8 @@ class WeatherAgent:
|
||||
|
||||
def __init__(self, gateway_url, push_to_explorer):
|
||||
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{self.dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
self.client = get_anthropic_client(
|
||||
gateway_url, push_to_explorer, self.dataset_name
|
||||
)
|
||||
self.get_weather_function = {
|
||||
"name": "get_weather",
|
||||
|
||||
@@ -8,10 +8,10 @@ import uuid
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import anthropic
|
||||
from utils import get_anthropic_client
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -26,15 +26,10 @@ async def test_response_without_tool_call(
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
client = get_anthropic_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
@@ -91,16 +86,7 @@ async def test_streaming_response_without_tool_call(
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
)
|
||||
client = get_anthropic_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
|
||||
@@ -60,6 +60,7 @@ services:
|
||||
app-api:
|
||||
container_name: invariant-gateway-test-explorer-app-api
|
||||
image: ghcr.io/invariantlabs-ai/explorer/app-api:latest
|
||||
pull_policy: always
|
||||
platform: linux/amd64
|
||||
depends_on:
|
||||
database:
|
||||
|
||||
@@ -8,9 +8,10 @@ import uuid
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_gemini_client
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
# Pytest plugins
|
||||
@@ -143,18 +144,7 @@ async def test_generate_content_with_tool_call(
|
||||
without streaming.
|
||||
"""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
|
||||
@@ -10,6 +10,8 @@ from unittest.mock import patch
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_gemini_client
|
||||
|
||||
import pytest
|
||||
import PIL.Image
|
||||
import requests
|
||||
@@ -29,17 +31,8 @@ async def test_generate_content(
|
||||
):
|
||||
"""Test the generate content gateway calls without tool calling."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
"contents": "What is the capital of France?",
|
||||
@@ -115,18 +108,8 @@ async def test_generate_content_with_image(
|
||||
):
|
||||
"""Test that generate content gateway calls work with image."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
client = get_gemini_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
|
||||
image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png"
|
||||
image = PIL.Image.open(image_path)
|
||||
@@ -181,9 +164,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"GEMINI_API_KEY": gemini_api_key + ";invariant-auth=<not needed for test>"},
|
||||
{"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
@@ -194,14 +178,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
|
||||
chat_response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents="What is the capital of Spain?",
|
||||
contents="What is the capital of Denmark?",
|
||||
config={
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the chat response
|
||||
assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper()
|
||||
assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper()
|
||||
expected_assistant_message = chat_response.candidates[0].content.parts[0].text
|
||||
|
||||
# Wait for the trace to be saved
|
||||
@@ -228,7 +212,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "What is the capital of Spain?", "type": "text"}],
|
||||
"content": [{"text": "What is the capital of Denmark?", "type": "text"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
||||
@@ -8,10 +8,11 @@ import time
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from anthropic import Anthropic, APIStatusError, BadRequestError
|
||||
from anthropic import APIStatusError, BadRequestError
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -32,16 +33,10 @@ async def test_message_content_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
|
||||
client = Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
client = get_anthropic_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
request = {
|
||||
@@ -161,16 +156,10 @@ async def test_tool_call_guardrail_from_file(
|
||||
}
|
||||
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
|
||||
client = Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
client = get_anthropic_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
if not do_stream:
|
||||
@@ -255,16 +244,10 @@ async def test_input_from_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
|
||||
client = Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
client = get_anthropic_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
request = {
|
||||
@@ -332,3 +315,276 @@ async def test_input_from_guardrail_from_file(
|
||||
== "Users must not mention the magic phrase 'Fight Club'"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
client = get_anthropic_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
|
||||
action="block",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
|
||||
action="log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
# Ask about the capital of Spain
|
||||
# This should not be blocked by the guardrails from the explorer when we push to explorer
|
||||
# because the file based guardrails are overridden by the explorer guardrails
|
||||
spain_request = {
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"messages": [{"role": "user", "content": "What is the capital of Spain?"}],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if not do_stream:
|
||||
chat_response = client.messages.create(
|
||||
**spain_request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert "Madrid" in chat_response.content[0].text
|
||||
else:
|
||||
chat_response = client.messages.create(
|
||||
**spain_request,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
merged_content = ""
|
||||
for chunk in chat_response:
|
||||
if chunk.type == "content_block_delta":
|
||||
merged_content += chunk.delta.text
|
||||
assert "Madrid" in merged_content
|
||||
|
||||
# Ask about Shrek
|
||||
# This should be blocked by the guardrails from the explorer
|
||||
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
|
||||
shrek_request = {
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.messages.create(
|
||||
**shrek_request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(exc_info.value)
|
||||
assert "Fiona detected in response" not in str(exc_info.value)
|
||||
else:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
chat_response = client.messages.create(
|
||||
**shrek_request,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for _ in chat_response:
|
||||
pass
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(exc_info.value)
|
||||
assert "Fiona detected in response" not in str(exc_info.value)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 2
|
||||
trace_id = traces[1]["id"]
|
||||
|
||||
# Fetch the second trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 2
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 2
|
||||
assert (
|
||||
annotations[0]["content"] == "ogre detected in response"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
)
|
||||
assert (
|
||||
annotations[1]["content"] == "Fiona detected in response"
|
||||
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, is_block_action",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
explorer_api_url, gateway_url, do_stream, is_block_action
|
||||
):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
client = get_anthropic_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
|
||||
action="block" if is_block_action else "log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
user_prompt = "Tell me a one sentence pun."
|
||||
request = {
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if is_block_action:
|
||||
if do_stream:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
chat_response = client.messages.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
for _ in chat_response:
|
||||
pass
|
||||
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
else:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.messages.create(
|
||||
**request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "pun detected in user message" in str(exc_info.value)
|
||||
|
||||
else:
|
||||
if do_stream:
|
||||
_ = client.messages.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
else:
|
||||
_ = client.messages.create(
|
||||
**request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 2 if not is_block_action else 1
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
if not is_block_action:
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"] == "pun detected in user message"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
if is_block_action
|
||||
else "log"
|
||||
)
|
||||
|
||||
@@ -8,9 +8,10 @@ import time
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from google import genai
|
||||
|
||||
# Pytest plugins
|
||||
@@ -30,17 +31,10 @@ async def test_message_content_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"headers": {
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
},
|
||||
client = get_gemini_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
request = {
|
||||
@@ -141,17 +135,10 @@ async def test_tool_call_guardrail_from_file(
|
||||
)
|
||||
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"headers": {
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
},
|
||||
client = get_gemini_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
request = {
|
||||
@@ -244,17 +231,10 @@ async def test_input_from_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"headers": {
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
},
|
||||
client = get_gemini_client(
|
||||
gateway_url,
|
||||
push_to_explorer,
|
||||
dataset_name,
|
||||
)
|
||||
|
||||
request = {
|
||||
@@ -323,6 +303,259 @@ async def test_input_from_guardrail_from_file(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
client = get_gemini_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
|
||||
action="block",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
|
||||
action="log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
# Ask about the capital of Spain
|
||||
# This should not be blocked by the guardrails from the explorer when we push to explorer
|
||||
# because the file based guardrails are overridden by the explorer guardrails
|
||||
spain_request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
"contents": "What is the capital of Spain?",
|
||||
"config": {
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
}
|
||||
if not do_stream:
|
||||
chat_response = client.models.generate_content(**spain_request)
|
||||
|
||||
assert "Madrid" in chat_response.candidates[0].content.parts[0].text
|
||||
else:
|
||||
chat_response = client.models.generate_content_stream(**spain_request)
|
||||
|
||||
merged_content = ""
|
||||
for chunk in chat_response:
|
||||
if (
|
||||
chunk.candidates
|
||||
and chunk.candidates[0].content
|
||||
and chunk.candidates[0].content.parts
|
||||
):
|
||||
for text_part in chunk.candidates[0].content.parts:
|
||||
merged_content += text_part.text
|
||||
assert "Madrid" in merged_content
|
||||
|
||||
# Ask about Shrek
|
||||
# This should be blocked by the guardrails from the explorer
|
||||
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
|
||||
shrek_request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
"contents": user_prompt,
|
||||
"config": {
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
}
|
||||
if not do_stream:
|
||||
with pytest.raises(genai.errors.ClientError) as exc_info:
|
||||
client.models.generate_content(**shrek_request)
|
||||
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(exc_info.value)
|
||||
assert "Fiona detected in response" not in str(exc_info.value)
|
||||
else:
|
||||
response = client.models.generate_content_stream(**shrek_request)
|
||||
|
||||
assert_is_streamed_refusal(
|
||||
response,
|
||||
[
|
||||
"[Invariant] The response did not pass the guardrails",
|
||||
"ogre detected in response",
|
||||
],
|
||||
)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 2
|
||||
trace_id = traces[1]["id"]
|
||||
|
||||
# Fetch the second trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 2
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_prompt,
|
||||
}
|
||||
],
|
||||
}
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 2
|
||||
assert (
|
||||
annotations[0]["content"] == "ogre detected in response"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
)
|
||||
assert (
|
||||
annotations[1]["content"] == "Fiona detected in response"
|
||||
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, is_block_action",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
explorer_api_url, gateway_url, do_stream, is_block_action
|
||||
):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
client = get_gemini_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
|
||||
action="block" if is_block_action else "log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
user_prompt = "Tell me a one sentence pun."
|
||||
request = {
|
||||
"model": "gemini-2.0-flash",
|
||||
"contents": user_prompt,
|
||||
"config": {
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
}
|
||||
if is_block_action:
|
||||
if do_stream:
|
||||
chat_response = client.models.generate_content_stream(**request)
|
||||
|
||||
assert_is_streamed_refusal(
|
||||
chat_response,
|
||||
[
|
||||
"[Invariant] The request did not pass the guardrails",
|
||||
"pun detected in user message",
|
||||
],
|
||||
)
|
||||
else:
|
||||
with pytest.raises(genai.errors.ClientError) as exc_info:
|
||||
chat_response = client.models.generate_content(**request)
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "pun detected in user message" in str(exc_info.value)
|
||||
else:
|
||||
if do_stream:
|
||||
response = client.models.generate_content_stream(**request)
|
||||
for _ in response:
|
||||
pass
|
||||
else:
|
||||
_ = client.models.generate_content(**request)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 2 if not is_block_action else 1
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_prompt,
|
||||
}
|
||||
],
|
||||
}
|
||||
if not is_block_action:
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"] == "pun detected in user message"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
if is_block_action
|
||||
else "log"
|
||||
)
|
||||
|
||||
|
||||
def is_refusal(chunk):
|
||||
return (
|
||||
len(chunk.candidates) == 1
|
||||
|
||||
@@ -8,10 +8,11 @@ import time
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from openai import OpenAI, BadRequestError, APIError
|
||||
from openai import BadRequestError, APIError
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -30,17 +31,7 @@ async def test_message_content_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
request = {
|
||||
"model": "gpt-4o",
|
||||
@@ -161,17 +152,7 @@ async def test_tool_call_guardrail_from_file(
|
||||
}
|
||||
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
@@ -259,17 +240,7 @@ async def test_input_from_guardrail_from_file(
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
request = {
|
||||
"model": "gpt-4o",
|
||||
@@ -349,3 +320,268 @@ async def test_input_from_guardrail_from_file(
|
||||
== "Users must not mention the magic phrase 'Fight Club'"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
client = get_open_ai_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "ogre detected in response" if:\n (msg: Message)\n "ogre" in msg.content and msg.role == "assistant"',
|
||||
action="block",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "Fiona detected in response" if:\n (msg: Message)\n "Fiona" in msg.content',
|
||||
action="log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
# Ask about the capital of Spain
|
||||
# This should not be blocked by the guardrails from the explorer when we push to explorer
|
||||
# because the file based guardrails are overridden by the explorer guardrails
|
||||
spain_request = {
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "What is the capital of Spain?"}],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if not do_stream:
|
||||
chat_response = client.chat.completions.create(
|
||||
**spain_request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert "Madrid" in chat_response.choices[0].message.content
|
||||
else:
|
||||
chat_response = client.chat.completions.create(
|
||||
**spain_request,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
merged_content = ""
|
||||
for chunk in chat_response:
|
||||
if chunk.choices[0].delta.content:
|
||||
merged_content += chunk.choices[0].delta.content
|
||||
assert "Madrid" in merged_content
|
||||
|
||||
# Ask about Shrek
|
||||
# This should be blocked by the guardrails from the explorer
|
||||
user_prompt = "What kind of a creature is Shrek? What is his Shrek's wife's name? Only answer these questions with single sentences, don't add any extra details."
|
||||
shrek_request = {
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**shrek_request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(exc_info.value)
|
||||
assert "Fiona detected in response" not in str(exc_info.value)
|
||||
else:
|
||||
with pytest.raises(APIError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**shrek_request,
|
||||
stream=True,
|
||||
)
|
||||
for _ in chat_response:
|
||||
pass
|
||||
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 2
|
||||
trace_id = traces[1]["id"]
|
||||
|
||||
# Fetch the second trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 2
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 2
|
||||
assert (
|
||||
annotations[0]["content"] == "ogre detected in response"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
)
|
||||
assert (
|
||||
annotations[1]["content"] == "Fiona detected in response"
|
||||
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, is_block_action",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
explorer_api_url, gateway_url, do_stream, is_block_action
|
||||
):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
client = get_open_ai_client(
|
||||
gateway_url, push_to_explorer=True, dataset_name=dataset_name
|
||||
)
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
|
||||
action="block" if is_block_action else "log",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
user_prompt = "Tell me a one sentence pun."
|
||||
request = {
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"max_tokens": 100,
|
||||
}
|
||||
if is_block_action:
|
||||
if do_stream:
|
||||
with pytest.raises(APIError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
for _ in chat_response:
|
||||
pass
|
||||
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
else:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "pun detected in user message" in str(exc_info.value)
|
||||
else:
|
||||
if do_stream:
|
||||
_ = client.chat.completions.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
else:
|
||||
_ = client.chat.completions.create(
|
||||
**request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
assert len(trace["messages"]) == 1 if is_block_action else 2
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
if not is_block_action:
|
||||
assert trace["messages"][1].get("role") == "assistant"
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"] == "pun detected in user message"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
|
||||
if is_block_action
|
||||
else "log"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Test the guardrails from file with the OpenAI route."""
|
||||
"""Test the guardrails from header with the OpenAI route."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
@@ -136,9 +136,7 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if:
|
||||
"do_stream, push_to_explorer",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_invalid_guardrail_in_header(
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
async def test_invalid_guardrail_in_header(gateway_url, do_stream, push_to_explorer):
|
||||
"""Test the message content guardrail."""
|
||||
if not os.getenv("INVARIANT_API_KEY"):
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
@@ -178,7 +176,8 @@ raise "Users must not mention the magic phrase 'Abracadabra'" if:
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert "Gateway: Guardrails check failed" in str(
|
||||
print(exc_info.value.message, flush=True)
|
||||
assert "Failed to create policy from policy source." in str(
|
||||
exc_info.value
|
||||
), "guardrails check fails because of an invalid guardrailing rule"
|
||||
assert "illegal statement" in str(
|
||||
|
||||
@@ -9,10 +9,10 @@ import uuid
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_open_ai_client
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from openai import OpenAI
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -28,17 +28,7 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
without streaming.
|
||||
"""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
@@ -146,17 +136,7 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
while streaming.
|
||||
"""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
|
||||
@@ -11,6 +11,8 @@ from unittest.mock import patch
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_open_ai_client
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
@@ -30,17 +32,7 @@ async def test_chat_completion(
|
||||
):
|
||||
"""Test the chat completions gateway calls without tool calling."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
@@ -103,17 +95,8 @@ async def test_chat_completion_with_image(
|
||||
):
|
||||
"""Test the chat completions gateway works with image."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png"
|
||||
with image_path.open("rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -189,9 +172,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_API_KEY": openai_api_key + ";invariant-auth=<not needed for test>"},
|
||||
{"OPENAI_API_KEY": openai_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = OpenAI(
|
||||
http_client=Client(),
|
||||
@@ -252,7 +236,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
|
||||
|
||||
+3
-3
@@ -1,15 +1,15 @@
|
||||
# For Anthropic and OpenAI tests:
|
||||
# For Anthropic and OpenAI tests
|
||||
raise "Madrid detected in the response" if:
|
||||
(msg: Message)
|
||||
"Madrid" in msg.content
|
||||
|
||||
# For Gemini tests:
|
||||
# For Gemini tests
|
||||
raise "Dublin detected in the response" if:
|
||||
(msg: Message)
|
||||
(chunk: str) in text(msg.content)
|
||||
"Dublin" in chunk
|
||||
|
||||
# For all provider tests:
|
||||
# For all provider tests
|
||||
raise "get_capital is called with Germany as argument" if:
|
||||
(call: ToolCall)
|
||||
call is tool:get_capital
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Common utilities for integration tests."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from httpx import AsyncClient, Client
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
from anthropic import Anthropic
|
||||
|
||||
|
||||
def get_open_ai_client(
|
||||
gateway_url: str, push_to_explorer: bool, dataset_name: str
|
||||
) -> OpenAI:
|
||||
"""Create an OpenAI client for integration tests."""
|
||||
return OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
|
||||
|
||||
def get_anthropic_client(
|
||||
gateway_url: str, push_to_explorer: bool, dataset_name: str
|
||||
) -> Anthropic:
|
||||
"""Create an Anthropic client for integration tests."""
|
||||
return Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/anthropic",
|
||||
)
|
||||
|
||||
|
||||
def get_gemini_client(
|
||||
gateway_url: str, push_to_explorer: bool, dataset_name: str
|
||||
) -> genai.Client:
|
||||
"""Create a Gemini client for integration tests."""
|
||||
return genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def create_dataset(
|
||||
explorer_api_url: str,
|
||||
invariant_authorization: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a dataset in the Explorer API."""
|
||||
client = Client(base_url=explorer_api_url)
|
||||
response = client.post(
|
||||
"/api/v1/dataset/create",
|
||||
json={"name": dataset_name if dataset_name else f"test-dataset-{uuid.uuid4()}"},
|
||||
headers={"Authorization": invariant_authorization},
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Failed to create dataset: {response.status_code}, {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
async def add_guardrail_to_dataset(
|
||||
explorer_api_url: str,
|
||||
dataset_id: str,
|
||||
policy: str,
|
||||
action: Literal["block", "log"],
|
||||
invariant_authorization: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a guardrail to a dataset."""
|
||||
client = Client(base_url=explorer_api_url)
|
||||
response = client.post(
|
||||
f"/api/v1/dataset/{dataset_id}/policy",
|
||||
json={
|
||||
"action": action,
|
||||
"policy": policy,
|
||||
"name": f"test-guardrail-{uuid.uuid4()}",
|
||||
},
|
||||
headers={"Authorization": invariant_authorization},
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Failed to add guardrail: {response.status_code}, {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
Reference in New Issue
Block a user