mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 15:29:43 +02:00
Add Guardrails verification logic for openai route.
This commit is contained in:
@@ -9,14 +9,14 @@ services:
|
||||
- .env
|
||||
environment:
|
||||
- DEV_MODE=true
|
||||
- POLICIES_FILE_PATH=${POLICIES_FILE_PATH:+/srv/resources/policies.py}
|
||||
- GUARDRAILS_FILE_PATH=${GUARDRAILS_FILE_PATH:+/srv/resources/guardrails.py}
|
||||
volumes:
|
||||
- type: bind
|
||||
source: ./gateway
|
||||
target: /srv/gateway
|
||||
- type: bind
|
||||
source: ${POLICIES_FILE_PATH:-/dev/null}
|
||||
target: /srv/resources/policies.py
|
||||
source: ${GUARDRAILS_FILE_PATH:-/dev/null}
|
||||
target: /srv/resources/guardrails.py
|
||||
networks:
|
||||
- invariant-explorer-web
|
||||
ports:
|
||||
|
||||
@@ -10,35 +10,35 @@ class GatewayConfig:
|
||||
"""Common configurations for the Gateway Server."""
|
||||
|
||||
def __init__(self):
|
||||
self.policies = self._load_policies()
|
||||
self.guardrails = self._load_guardrails()
|
||||
|
||||
def _load_policies(self) -> str:
|
||||
def _load_guardrails(self) -> str:
|
||||
"""
|
||||
Loads and validates policies from the file specified in POLICIES_FILE_PATH.
|
||||
Returns the policy file content as a string if valid; otherwise, raises an error.
|
||||
Loads and validates guardrails from the file specified in GUARDRAILS_FILE_PATH.
|
||||
Returns the guardrails file content as a string if valid; otherwise, raises an error.
|
||||
"""
|
||||
policies_file = os.getenv("POLICIES_FILE_PATH", "")
|
||||
guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "")
|
||||
|
||||
if not policies_file:
|
||||
print("[warning: POLICIES_FILE_PATH is not set. Using empty policies]")
|
||||
if not guardrails_file:
|
||||
print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(policies_file, "r", encoding="utf-8") as f:
|
||||
policy_file_content = f.read()
|
||||
_ = Policy.from_string(policy_file_content)
|
||||
return policy_file_content
|
||||
with open(guardrails_file, "r", encoding="utf-8") as f:
|
||||
guardrails_file_content = f.read()
|
||||
_ = Policy.from_string(guardrails_file_content)
|
||||
return guardrails_file_content
|
||||
|
||||
except (FileNotFoundError, PermissionError, OSError) as e:
|
||||
raise ValueError(
|
||||
f"Error: Unable to read policies file ({policies_file}): {e}"
|
||||
f"Error: Unable to read guardrails file ({guardrails_file}): {e}"
|
||||
) from e
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid policy content in {policies_file}: {e}") from e
|
||||
raise ValueError(f"Invalid policy content in {guardrails_file}: {e}") from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"GatewayConfig(policies={repr(self.policies)})"
|
||||
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
|
||||
|
||||
|
||||
class GatewayConfigManager:
|
||||
|
||||
@@ -3,9 +3,14 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from config_manager import GatewayConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContextData:
|
||||
"""Request context data class."""
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
invariant_authorization: Optional[str] = None
|
||||
config: Optional[GatewayConfig] = None
|
||||
|
||||
@@ -5,9 +5,33 @@ from typing import Any, Dict, List
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
from invariant_sdk.types.annotations import AnnotationCreate
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def create_annotations_from_guardrails_errors(
|
||||
guardrails_errors: List[dict],
|
||||
) -> List[AnnotationCreate]:
|
||||
"""Create Explorer annotations from the guardrails errors."""
|
||||
annotations = []
|
||||
for error in guardrails_errors:
|
||||
content = error.get("args")[0]
|
||||
address = None
|
||||
for r in error.get("ranges", []):
|
||||
# Choose the longest path as the address
|
||||
if address is None or len(r) > len(address):
|
||||
address = r
|
||||
annotations.append(
|
||||
AnnotationCreate(
|
||||
content=content,
|
||||
address=address,
|
||||
extra_metadata={"source": "guardrails-error"},
|
||||
)
|
||||
)
|
||||
return annotations
|
||||
|
||||
|
||||
async def push_trace(
|
||||
messages: List[List[Dict[str, Any]]],
|
||||
dataset_name: str,
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Utility functions for Guardrails execution."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import httpx
|
||||
|
||||
DEFAULT_API_URL = "https://guardrail.invariantnet.com"
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks guardrails on the list of messages.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
|
||||
guardrails (str): The guardrails to check against.
|
||||
invariant_authorization (str): Value of the
|
||||
invariant-authorization header.
|
||||
|
||||
Returns:
|
||||
Dict: Response containing guardrail check results.
|
||||
"""
|
||||
client = httpx.AsyncClient()
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check",
|
||||
json={"messages": messages, "policy": guardrails},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
return {"error": str(e)}
|
||||
+49
-17
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
@@ -12,7 +12,8 @@ from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import push_trace
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardails import check_guardrails
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
|
||||
@@ -68,6 +69,7 @@ async def openai_chat_completions_gateway(
|
||||
request_json=request_json,
|
||||
dataset_name=dataset_name,
|
||||
invariant_authorization=invariant_authorization,
|
||||
config=config,
|
||||
)
|
||||
|
||||
if request_json.get("stream", False):
|
||||
@@ -282,24 +284,30 @@ def update_existing_choice_with_delta(
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
context: RequestContextData, merged_response: dict[str, Any]
|
||||
context: RequestContextData,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
"""Pushes the merged response to the Invariant Explorer"""
|
||||
# Only push the trace to explorer if the message is an end turn message
|
||||
if (
|
||||
# or if the guardrails check returned errors.
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
guardrails_errors = guardrails_execution_result.get("errors", [])
|
||||
if guardrails_errors or not (
|
||||
merged_response.get("choices")
|
||||
and merged_response["choices"][0].get("finish_reason")
|
||||
not in FINISH_REASON_TO_PUSH_TRACE
|
||||
):
|
||||
return
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = context.request_json.get("messages", [])
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
_ = await push_trace(
|
||||
dataset_name=context.dataset_name,
|
||||
messages=[messages],
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = context.request_json.get("messages", [])
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
_ = await push_trace(
|
||||
dataset_name=context.dataset_name,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
messages=[messages],
|
||||
annotations=[annotations],
|
||||
)
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
@@ -318,13 +326,37 @@ async def handle_non_streaming_response(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from OpenAI API"),
|
||||
)
|
||||
|
||||
guardrails_execution_result = {}
|
||||
response_string = json.dumps(json_response)
|
||||
response_code = response.status_code
|
||||
|
||||
if context.config and context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
messages=messages,
|
||||
guardrails=context.config.guardrails,
|
||||
invariant_authorization=context.invariant_authorization,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
{
|
||||
"error": "The request did not pass the guardrails",
|
||||
"guadrails_check_result": guardrails_execution_result,
|
||||
}
|
||||
)
|
||||
response_code = 400
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(push_to_explorer(context, json_response))
|
||||
asyncio.create_task(
|
||||
push_to_explorer(context, json_response, guardrails_execution_result)
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
status_code=response.status_code,
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
@@ -4,13 +4,13 @@ up() {
|
||||
docker network create invariant-explorer-web
|
||||
|
||||
# Default values
|
||||
POLICIES_FILE_PATH=""
|
||||
GUARDRAILS_FILE_PATH=""
|
||||
|
||||
# Parse command-line arguments
|
||||
while [[ "$#" -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--policies-file=*)
|
||||
POLICIES_FILE_PATH="${1#*=}"
|
||||
--guardrails-file=*)
|
||||
GUARDRAILS_FILE_PATH="${1#*=}"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown parameter: $1"
|
||||
@@ -20,21 +20,21 @@ up() {
|
||||
shift
|
||||
done
|
||||
|
||||
if [[ -n "$POLICIES_FILE_PATH" ]]; then
|
||||
if [[ -f "$POLICIES_FILE_PATH" ]]; then
|
||||
POLICIES_FILE_PATH=$(realpath "$POLICIES_FILE_PATH")
|
||||
if [[ -n "$GUARDRAILS_FILE_PATH" ]]; then
|
||||
if [[ -f "$GUARDRAILS_FILE_PATH" ]]; then
|
||||
GUARDRAILS_FILE_PATH=$(realpath "$GUARDRAILS_FILE_PATH")
|
||||
else
|
||||
echo "Error: Specified policies file does not exist: $POLICIES_FILE_PATH"
|
||||
echo "Error: Specified guardrails file does not exist: $GUARDRAILS_FILE_PATH"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Start Docker Compose with the correct environment variable
|
||||
POLICIES_FILE_PATH="$POLICIES_FILE_PATH" docker compose -f docker-compose.local.yml up -d
|
||||
GUARDRAILS_FILE_PATH="$GUARDRAILS_FILE_PATH" docker compose -f docker-compose.local.yml up -d
|
||||
|
||||
echo "Gateway started at http://localhost:8005/api/v1/gateway/"
|
||||
echo "See http://localhost:8005/api/v1/gateway/docs for API documentation"
|
||||
echo "Using Policies File: ${POLICIES_FILE_PATH:-None}"
|
||||
echo "Using Guardrails File: ${GUARDRAILS_FILE_PATH:-None}"
|
||||
}
|
||||
|
||||
build() {
|
||||
|
||||
Reference in New Issue
Block a user