mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-24 15:54:05 +02:00
Fetch guardrails from explorer. These have higher precedence than the guardrails from file.
This commit is contained in:
@@ -11,7 +11,7 @@ class GatewayConfig:
|
||||
"""Common configurations for the Gateway Server."""
|
||||
|
||||
def __init__(self):
|
||||
self.guardrails = self._load_guardrails_from_file()
|
||||
self.guardrails_from_file = self._load_guardrails_from_file()
|
||||
|
||||
def _load_guardrails_from_file(self) -> str:
|
||||
"""
|
||||
@@ -48,7 +48,7 @@ class GatewayConfig:
|
||||
raise ValueError(f"Cannot load guardrails, {e}, {e.response.text}") from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
|
||||
return f"GatewayConfig(guardrails_from_file={repr(self.guardrails_from_file)})"
|
||||
|
||||
|
||||
class GatewayConfigManager:
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Common guardrails data class."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class GuardrailAction(str, Enum):
    """Enum representing the action to be taken for guardrail rules.

    The enum mixes in ``str``, so each member compares equal to its raw
    string value (``GuardrailAction.BLOCK == "block"``) and a raw action
    string can be converted with ``GuardrailAction("block")``.
    """

    BLOCK = "block"
    LOG = "log"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Guardrail:
    """Represents a single guardrail rule.

    Instances are immutable (``frozen=True``): attributes cannot be
    reassigned after construction.
    """

    # Identifier of the rule.
    id: str
    # Name of the rule.
    name: str
    # The rule body (policy text) of the guardrail.
    content: str
    # What to do with the rule's result: block or log.
    action: GuardrailAction
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DatasetGuardrails:
    """Grouped guardrail rules separated by their action.

    The dataclass is frozen, so the two attributes cannot be rebound; the
    lists they reference are ordinary (mutable) Python lists.
    """

    # Rules whose action is GuardrailAction.BLOCK.
    blocking_guardrails: List[Guardrail]
    # Rules whose action is GuardrailAction.LOG.
    logging_guardrails: List[Guardrail]
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Common Request context data class."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from common.config_manager import GatewayConfig
|
||||
from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RequestContext:
    """Structured context for a request. Must be created via `RequestContext.create()`.

    Direct construction raises ``RuntimeError`` (enforced in
    ``__post_init__``); ``create()`` is the only supported entry point
    because it also resolves which guardrails apply to the request.
    """

    # Parsed JSON body of the incoming request.
    request_json: Dict[str, Any]
    dataset_name: Optional[str] = None
    # Value of the invariant authorization header; treated as a secret.
    invariant_authorization: Optional[str] = None
    # Guardrails in effect for this request (from Explorer, or the file
    # fallback applied by ``create()``).
    dataset_guardrails: Optional[DatasetGuardrails] = None
    # Plain-dict snapshot of the gateway config, minus ``guardrails_from_file``.
    config: Optional[Dict[str, Any]] = None

    # Set to True only by ``create()``; checked in ``__post_init__``.
    _created_via_factory: bool = field(
        default=False, init=True, repr=False, compare=False
    )

    def __post_init__(self):
        """Reject instances that were not built through ``create()``."""
        if not self._created_via_factory:
            raise RuntimeError(
                "RequestContext must be created using RequestContext.create()"
            )

    @classmethod
    def create(
        cls,
        request_json: Dict[str, Any],
        dataset_name: Optional[str] = None,
        invariant_authorization: Optional[str] = None,
        dataset_guardrails: Optional[DatasetGuardrails] = None,
        config: Optional[GatewayConfig] = None,
    ) -> "RequestContext":
        """Creates a new RequestContext instance, applying default guardrails if needed.

        Args:
            request_json: Parsed JSON body of the request.
            dataset_name: Name of the dataset the request is associated with.
            invariant_authorization: Value of the invariant authorization header.
            dataset_guardrails: Guardrails configured for the dataset, if any.
            config: Gateway configuration; its ``guardrails_from_file`` is used
                as a fallback when ``dataset_guardrails`` is missing or empty.

        Returns:
            RequestContext: The fully populated, immutable request context.
        """
        # Convert GatewayConfig to a basic dict, excluding guardrails_from_file
        # (that value is carried via dataset_guardrails instead).
        config_attributes = vars(config) if config else {}
        context_config = {
            key: value
            for key, value in config_attributes.items()
            if key != "guardrails_from_file"
        }

        has_dataset_guardrails = bool(
            dataset_guardrails
            and (
                dataset_guardrails.blocking_guardrails
                or dataset_guardrails.logging_guardrails
            )
        )

        # If no guardrails are configured for the dataset on Explorer,
        # and the config specifies guardrails_from_file, use that.
        guardrails = dataset_guardrails
        if not has_dataset_guardrails and config and config.guardrails_from_file:
            # TODO: Support logging guardrails via file.
            guardrails = DatasetGuardrails(
                blocking_guardrails=[
                    Guardrail(
                        id="default",
                        name="default",
                        content=config.guardrails_from_file,
                        action=GuardrailAction.BLOCK,
                    )
                ],
                logging_guardrails=[],
            )

        return cls(
            request_json=request_json,
            dataset_name=dataset_name,
            invariant_authorization=invariant_authorization,
            dataset_guardrails=guardrails,
            config=context_config,
            _created_via_factory=True,
        )

    def __repr__(self) -> str:
        # Never print the authorization value itself: the repr can end up
        # in logs. Only reveal whether one is present.
        authorization = "<redacted>" if self.invariant_authorization else None
        return (
            f"RequestContext("
            f"request_json={self.request_json}, "
            f"dataset_name={self.dataset_name}, "
            f"invariant_authorization={authorization}, "
            f"dataset_guardrails={self.dataset_guardrails}, "
            f"config={self.config})"
        )
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Common Request context data class."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from common.config_manager import GatewayConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContextData:
|
||||
"""Request context data class."""
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
invariant_authorization: Optional[str] = None
|
||||
config: Optional[GatewayConfig] = None
|
||||
@@ -3,10 +3,13 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from common.guardrails import DatasetGuardrails, Guardrail, GuardrailAction
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
from invariant_sdk.types.annotations import AnnotationCreate
|
||||
|
||||
import httpx
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
@@ -91,3 +94,79 @@ async def push_trace(
|
||||
except Exception as e:
|
||||
print(f"Failed to push trace: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def fetch_guardrails_from_explorer(
    dataset_name: str, invariant_authorization: str
) -> DatasetGuardrails:
    """Get the guardrails for the dataset.

    Two Explorer calls are made: one to resolve the username behind the
    authorization value, and one to read that user's dataset policies.

    Args:
        dataset_name: Name of the dataset whose policies are fetched.
        invariant_authorization: Value sent as the ``Invariant-Authorization``
            header on both requests.

    Returns:
        DatasetGuardrails: The guardrails for the dataset grouped by their action.
        Empty groups are returned when the policy endpoint answers 404.

    Raises:
        ValueError: If the user info request, or the policy request, fails
            with a status other than 200 (404 on the policy request excepted).
    """

    # TODO: Implement a single API in explorer backend which can return
    # dataset details without requiring a username.

    # `async with` guarantees the client (and its connection pool) is closed
    # on every path, including the early raise below.
    async with httpx.AsyncClient(
        base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"),
        headers={
            "Invariant-Authorization": invariant_authorization,
        },
    ) as client:
        # Get the user details.
        user_info_response = await client.get("/api/v1/user/info")
        if user_info_response.status_code != 200:
            raise ValueError(
                f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
            )
        user_details = user_info_response.json()
        username = user_details["username"]

        # Get the dataset policies.
        policies_response = await client.get(
            f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy"
        )

    # The response body has already been read (non-streaming request), so it
    # can be inspected after the client is closed.
    if policies_response.status_code == 404:
        # If the dataset does not exist, return empty guardrails.
        return DatasetGuardrails(
            blocking_guardrails=[],
            logging_guardrails=[],
        )
    if policies_response.status_code != 200:
        raise ValueError(
            f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}"
        )

    policies_details = policies_response.json()
    # Tolerate both a missing key and an explicit null.
    guardrails = policies_details.get("policies") or []

    blocking_guardrails = []
    logging_guardrails = []
    for g in guardrails:
        if not g["enabled"]:
            # Skip guardrails that are not enabled.
            continue

        action = g["action"]
        if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG):
            print("[Warning] Skipping unknown guardrail action: ", action)
            continue

        guardrail = Guardrail(
            id=g["id"],
            name=g["name"],
            content=g["content"],
            action=GuardrailAction(action),
        )

        if action == GuardrailAction.BLOCK:
            blocking_guardrails.append(guardrail)
        else:
            logging_guardrails.append(guardrail)

    return DatasetGuardrails(
        blocking_guardrails=blocking_guardrails,
        logging_guardrails=logging_guardrails,
    )
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""Utility functions for Guardrails execution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
import httpx
|
||||
from common.request_context_data import RequestContextData
|
||||
from common.guardrails import Guardrail
|
||||
from common.request_context import RequestContext
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
@@ -81,21 +83,28 @@ async def _preload(guardrails: str, invariant_authorization: str) -> None:
|
||||
result.raise_for_status()
|
||||
|
||||
|
||||
async def preload_guardrails(context: "RequestContextData") -> None:
|
||||
async def preload_guardrails(context: "RequestContext") -> None:
|
||||
"""
|
||||
Preloads the guardrails for faster checking later.
|
||||
|
||||
Args:
|
||||
context: RequestContextData object.
|
||||
context: RequestContext object.
|
||||
"""
|
||||
if not context.config or not context.config.guardrails:
|
||||
if not context.dataset_guardrails:
|
||||
return
|
||||
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
_preload(context.config.guardrails, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
# Move these calls to a batch preload/validate API.
|
||||
for blocking_guardrail in context.dataset_guardrails.blocking_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(blocking_guardrail.content, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
for logging_guadrail in context.dataset_guardrails.logging_guardrails:
|
||||
task = asyncio.create_task(
|
||||
_preload(logging_guadrail.content, context.invariant_authorization)
|
||||
)
|
||||
asyncio.shield(task)
|
||||
except Exception as e:
|
||||
print(f"Error scheduling preload_guardrails task: {e}")
|
||||
|
||||
@@ -322,14 +331,17 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
|
||||
messages: List[Dict[str, Any]],
|
||||
guardrails: List[Guardrail],
|
||||
invariant_authorization: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks guardrails on the list of messages.
|
||||
This calls the batch check API of the Guardrails service.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
|
||||
guardrails (str): The guardrails to check against.
|
||||
guardrails (List[Guardrail]): The guardrails to check against.
|
||||
invariant_authorization (str): Value of the
|
||||
invariant-authorization header.
|
||||
|
||||
@@ -339,9 +351,34 @@ async def check_guardrails(
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
print(
|
||||
"Hello there this is the request to guardrails: ",
|
||||
json.dumps(
|
||||
{
|
||||
"messages": messages,
|
||||
"policies": [g.content for g in guardrails],
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
"Hello there this is the request to guardrails: ",
|
||||
json.dumps(
|
||||
{
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check",
|
||||
json={"messages": messages, "policy": guardrails},
|
||||
f"{url}/api/v1/policy/check/batch",
|
||||
json={
|
||||
"messages": messages,
|
||||
"policies": [g.content for g in guardrails],
|
||||
},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
@@ -352,7 +389,12 @@ async def check_guardrails(
|
||||
f"Guardrails check failed: {result.status_code} - {result.text}"
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
|
||||
guardrails_result = result.json()
|
||||
aggregated_errors = {"errors": []}
|
||||
for res in guardrails_result.get("result", []):
|
||||
aggregated_errors["errors"].extend(res.get("errors", []))
|
||||
return aggregated_errors
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
+68
-29
@@ -14,11 +14,16 @@ from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from common.request_context_data import RequestContextData
|
||||
from common.guardrails import GuardrailAction
|
||||
from common.request_context import RequestContext
|
||||
from converters.anthropic_to_invariant import (
|
||||
convert_anthropic_to_invariant_message_format,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -83,10 +88,17 @@ async def anthropic_v1_messages_gateway(
|
||||
data=request_body,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset from explorer.
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
dataset_guardrails=dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
@@ -97,7 +109,7 @@ async def anthropic_v1_messages_gateway(
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {k: v for k, v in context.request_json.items() if k != "messages"}
|
||||
@@ -108,7 +120,7 @@ def create_metadata(
|
||||
|
||||
|
||||
def combine_request_and_response_messages(
|
||||
context: RequestContextData, json_response: dict[str, Any]
|
||||
context: RequestContext, json_response: dict[str, Any]
|
||||
):
|
||||
"""Combine the request and response messages"""
|
||||
messages = []
|
||||
@@ -123,23 +135,32 @@ def combine_request_and_response_messages(
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, json_response: dict[str, Any]
|
||||
context: RequestContext, action: GuardrailAction, json_response: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.dataset_guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.dataset_guardrails.blocking_guardrails
|
||||
)
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
messages = combine_request_and_response_messages(context, json_response)
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_messages,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
@@ -163,14 +184,16 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
"""Instrumented response for Anthropic API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
@@ -184,9 +207,9 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, json_response={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -264,10 +287,17 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
assert self.json_response is not None, "json_response is None"
|
||||
assert self.response_string is not None, "response_string is None"
|
||||
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.dataset_guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.json_response,
|
||||
)
|
||||
print(
|
||||
"Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ",
|
||||
guardrails_execution_result,
|
||||
flush=True,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
guardrail_response_string = json.dumps(
|
||||
@@ -306,7 +336,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -320,17 +350,19 @@ async def handle_non_streaming_response(
|
||||
return await response.instrumented_request()
|
||||
|
||||
|
||||
class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
"""Instrumented streaming response for Anthropic API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
@@ -342,9 +374,11 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -392,6 +426,7 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Process the chunk and update the merged_response"""
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
return
|
||||
@@ -400,14 +435,17 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
process_chunk(decoded_chunk, self.merged_response)
|
||||
|
||||
# on last stream chunk, run output guardrails
|
||||
if (
|
||||
"event: message_stop" in decoded_chunk
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
):
|
||||
if "event: message_stop" in decoded_chunk and self.context.dataset_guardrails:
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
)
|
||||
print(
|
||||
"Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ",
|
||||
self.guardrails_execution_result,
|
||||
flush=True,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -420,7 +458,8 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
}
|
||||
)
|
||||
|
||||
# yield an extra error chunk (without preventing the original chunk to go through after,
|
||||
# yield an extra error chunk (without preventing the original chunk
|
||||
# to go through after,
|
||||
# so client gets the proper message_stop event still)
|
||||
return ExtraItem(
|
||||
value=f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
@@ -440,12 +479,12 @@ class InstrumentedAnthropicStreamingResposne(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
response = InstrumentedAnthropicStreamingResposne(
|
||||
response = InstrumentedAnthropicStreamingResponse(
|
||||
context=context,
|
||||
client=client,
|
||||
anthropic_request=anthropic_request,
|
||||
|
||||
+62
-24
@@ -14,9 +14,14 @@ from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from common.request_context_data import RequestContextData
|
||||
from common.guardrails import GuardrailAction
|
||||
from common.request_context import RequestContext
|
||||
from converters.gemini_to_invariant import convert_request, convert_response
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -76,10 +81,17 @@ async def gemini_generate_content_gateway(
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
dataset_guardrails=dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
@@ -98,16 +110,18 @@ async def gemini_generate_content_gateway(
|
||||
|
||||
|
||||
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
"""Instrumented streaming response for Gemini API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request data
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.gemini_request: httpx.Request = gemini_request
|
||||
|
||||
@@ -124,6 +138,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
location: Literal["request", "response"],
|
||||
guardrails_execution_result: dict[str, Any],
|
||||
) -> dict:
|
||||
"""Create a refusal response for the given request or response"""
|
||||
return {
|
||||
"candidates": [
|
||||
{
|
||||
@@ -157,10 +172,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -184,6 +202,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def event_generator(self):
|
||||
"""Event generator for streaming responses"""
|
||||
response = await self.client.send(self.gemini_request, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -199,6 +218,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
yield chunk
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Processes each chunk of the streaming response"""
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
return
|
||||
@@ -210,12 +230,13 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
if (
|
||||
self.merged_response.get("candidates", [])
|
||||
and self.merged_response.get("candidates")[0].get("finishReason", "")
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
and self.context.dataset_guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -254,7 +275,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -332,7 +353,7 @@ def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) ->
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {
|
||||
@@ -352,23 +373,32 @@ def create_metadata(
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, response_json: dict[str, Any]
|
||||
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.dataset_guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.dataset_guardrails.blocking_guardrails
|
||||
)
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
converted_requests = convert_request(context.request_json)
|
||||
converted_responses = convert_response(response_json)
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=converted_requests + converted_responses,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
response_json: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
@@ -391,16 +421,18 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
"""Instrumented response for Gemini API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request data
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.gemini_request: httpx.Request = gemini_request
|
||||
|
||||
@@ -412,10 +444,13 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
self.guardrails_execution_result: Optional[dict[str, Any]] = None
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, {}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -463,6 +498,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def request(self):
|
||||
"""Makes the request to the Gemini API and return the response"""
|
||||
self.response = await self.client.send(self.gemini_request)
|
||||
|
||||
response_string = self.response.text
|
||||
@@ -492,10 +528,12 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
response_string = json.dumps(self.response_json)
|
||||
response_code = self.response.status_code
|
||||
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.dataset_guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.response_json
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
@@ -539,7 +577,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
) -> Response:
|
||||
|
||||
+66
-35
@@ -14,8 +14,13 @@ from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from common.request_context_data import RequestContextData
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from common.guardrails import GuardrailAction
|
||||
from common.request_context import RequestContext
|
||||
from integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
push_trace,
|
||||
)
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
@@ -72,10 +77,17 @@ async def openai_chat_completions_gateway(
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
context = RequestContextData(
|
||||
dataset_guardrails = None
|
||||
if dataset_name:
|
||||
# Get the guardrails for the dataset
|
||||
dataset_guardrails = await fetch_guardrails_from_explorer(
|
||||
dataset_name, invariant_authorization
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
dataset_guardrails=dataset_guardrails,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
@@ -92,19 +104,20 @@ async def openai_chat_completions_gateway(
|
||||
|
||||
class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
"""
|
||||
Does a streaming OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
|
||||
Does a streaming OpenAI completion request at the core, but also checks guardrails
|
||||
before (concurrent) and after the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
@@ -131,10 +144,15 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
self.tool_call_mapping_by_index = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first chunk
|
||||
(for input guardrailing).
|
||||
"""
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -164,6 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def on_chunk(self, chunk):
|
||||
"""Processes each chunk of the stream and checks guardrails at the end of the stream"""
|
||||
# process and check each chunk
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
@@ -179,14 +198,12 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
# check guardrails at the end of the stream (on the '[DONE]' SSE chunk.)
|
||||
if (
|
||||
"data: [DONE]" in chunk_text
|
||||
and self.context.config
|
||||
and self.context.config.guardrails
|
||||
):
|
||||
if "data: [DONE]" in chunk_text and self.context.dataset_guardrails:
|
||||
# Block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.merged_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -214,10 +231,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
)
|
||||
|
||||
async def event_generator(self):
|
||||
"""
|
||||
Actual OpenAI stream response.
|
||||
"""
|
||||
|
||||
"""Actual OpenAI stream response."""
|
||||
response = await self.client.send(self.open_ai_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
@@ -234,7 +248,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def handle_stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
@@ -389,7 +403,7 @@ def update_existing_choice_with_delta(
|
||||
|
||||
|
||||
def create_metadata(
|
||||
context: RequestContextData, merged_response: dict[str, Any]
|
||||
context: RequestContext, merged_response: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Creates metadata for the trace"""
|
||||
metadata = {
|
||||
@@ -409,7 +423,7 @@ def create_metadata(
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
@@ -437,18 +451,28 @@ async def push_to_explorer(
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContextData, json_response: dict[str, Any] | None = None
|
||||
context: RequestContext,
|
||||
action: GuardrailAction,
|
||||
json_response: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
# Determine which guardrails to apply based on the action
|
||||
guardrails = (
|
||||
context.dataset_guardrails.logging_guardrails
|
||||
if action == GuardrailAction.LOG
|
||||
else context.dataset_guardrails.blocking_guardrails
|
||||
)
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
if json_response is not None:
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=context.config.guardrails,
|
||||
guardrails=guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
return guardrails_execution_result
|
||||
@@ -456,19 +480,20 @@ async def get_guardrails_check_result(
|
||||
|
||||
class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
"""
|
||||
Does an OpenAI completion request at the core, but also checks guardrails before (concurrent) and after the request.
|
||||
Does an OpenAI completion request at the core, but also checks guardrails
|
||||
before (concurrent) and after the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# request parameters
|
||||
self.context: RequestContextData = context
|
||||
self.context: RequestContext = context
|
||||
self.client: httpx.AsyncClient = client
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
@@ -480,11 +505,14 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
|
||||
async def on_start(self):
|
||||
"""Checks guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)"""
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
"""
|
||||
Checks guardrails in a pipelined fashion, before processing
|
||||
the first chunk (for input guardrailing)
|
||||
"""
|
||||
if self.context.dataset_guardrails:
|
||||
# block on the guardrails check
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context
|
||||
self.context, action=GuardrailAction.BLOCK
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
@@ -542,7 +570,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
async def on_end(self):
|
||||
"""Postprocesses the OpenAI response and potentially replace it with a guardrails error."""
|
||||
|
||||
# these two request outputs are guaranteed to be available by the time we reach this point (after self.request() was executed)
|
||||
# these two request outputs are guaranteed to be available by the time we reach
|
||||
# this point (after self.request() was executed)
|
||||
# nevertheless, we check for them to avoid any potential issues
|
||||
assert (
|
||||
self.response is not None
|
||||
@@ -555,10 +584,12 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
response_code = self.response.status_code
|
||||
|
||||
# if we have guardrails, check the response
|
||||
if self.context.config and self.context.config.guardrails:
|
||||
if self.context.dataset_guardrails:
|
||||
# run guardrails again, this time on request + response
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, self.json_response
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.json_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
@@ -601,7 +632,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
|
||||
|
||||
async def handle_non_stream_response(
|
||||
context: RequestContextData,
|
||||
context: RequestContext,
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
) -> Response:
|
||||
|
||||
@@ -93,6 +93,11 @@ integration_tests() {
|
||||
fi
|
||||
echo "File successfully downloaded: $FILE"
|
||||
|
||||
if [[ -z "$INVARIANT_API_KEY" ]]; then
|
||||
echo "Error: INVARIANT_API_KEY env var is not set. This is required to run integration tests."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TEST_GUARDRAILS_FILE_PATH="tests/integration/resources/guardrails/find_capital_guardrails.py"
|
||||
if [[ -n "$TEST_GUARDRAILS_FILE_PATH" ]]; then
|
||||
if [[ -f "$TEST_GUARDRAILS_FILE_PATH" ]]; then
|
||||
|
||||
@@ -27,12 +27,10 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
|
||||
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"ANTHROPIC_API_KEY": anthropic_api_key
|
||||
+ ";invariant-auth=<not needed for test>"
|
||||
},
|
||||
{"ANTHROPIC_API_KEY": anthropic_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(),
|
||||
|
||||
@@ -26,7 +26,7 @@ class WeatherAgent:
|
||||
|
||||
def __init__(self, gateway_url, push_to_explorer):
|
||||
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
|
||||
@@ -26,7 +26,7 @@ async def test_response_without_tool_call(
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
@@ -91,7 +91,7 @@ async def test_streaming_response_without_tool_call(
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
|
||||
@@ -151,7 +151,7 @@ async def test_generate_content_with_tool_call(
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ async def test_generate_content(
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
@@ -123,7 +123,7 @@ async def test_generate_content_with_image(
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/gemini",
|
||||
"headers": {
|
||||
"invariant-authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
},
|
||||
)
|
||||
@@ -181,9 +181,10 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"GEMINI_API_KEY": gemini_api_key + ";invariant-auth=<not needed for test>"},
|
||||
{"GEMINI_API_KEY": gemini_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
|
||||
@@ -32,7 +32,7 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
@@ -150,7 +150,7 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
|
||||
@@ -34,7 +34,7 @@ async def test_chat_completion(
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
@@ -107,7 +107,7 @@ async def test_chat_completion_with_image(
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
@@ -189,9 +189,10 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_API_KEY": openai_api_key + ";invariant-auth=<not needed for test>"},
|
||||
{"OPENAI_API_KEY": openai_api_key + invariant_key_suffix},
|
||||
):
|
||||
client = OpenAI(
|
||||
http_client=Client(),
|
||||
@@ -252,7 +253,7 @@ async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
|
||||
|
||||
Reference in New Issue
Block a user