Add Guardrails verification logic for openai route.

This commit is contained in:
Hemang
2025-03-13 17:18:57 +01:00
committed by Hemang Sarkar
parent 5e452af1ae
commit e773cc9f2d
7 changed files with 145 additions and 43 deletions
+3 -3
View File
@@ -9,14 +9,14 @@ services:
- .env
environment:
- DEV_MODE=true
- POLICIES_FILE_PATH=${POLICIES_FILE_PATH:+/srv/resources/policies.py}
- GUARDRAILS_FILE_PATH=${GUARDRAILS_FILE_PATH:+/srv/resources/guardrails.py}
volumes:
- type: bind
source: ./gateway
target: /srv/gateway
- type: bind
source: ${POLICIES_FILE_PATH:-/dev/null}
target: /srv/resources/policies.py
source: ${GUARDRAILS_FILE_PATH:-/dev/null}
target: /srv/resources/guardrails.py
networks:
- invariant-explorer-web
ports:
+14 -14
View File
@@ -10,35 +10,35 @@ class GatewayConfig:
"""Common configurations for the Gateway Server."""
def __init__(self):
self.policies = self._load_policies()
self.guardrails = self._load_guardrails()
def _load_policies(self) -> str:
def _load_guardrails(self) -> str:
"""
Loads and validates policies from the file specified in POLICIES_FILE_PATH.
Returns the policy file content as a string if valid; otherwise, raises an error.
Loads and validates guardrails from the file specified in GUARDRAILS_FILE_PATH.
Returns the guardrails file content as a string if valid; otherwise, raises an error.
"""
policies_file = os.getenv("POLICIES_FILE_PATH", "")
guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "")
if not policies_file:
print("[warning: POLICIES_FILE_PATH is not set. Using empty policies]")
if not guardrails_file:
print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]")
return ""
try:
with open(policies_file, "r", encoding="utf-8") as f:
policy_file_content = f.read()
_ = Policy.from_string(policy_file_content)
return policy_file_content
with open(guardrails_file, "r", encoding="utf-8") as f:
guardrails_file_content = f.read()
_ = Policy.from_string(guardrails_file_content)
return guardrails_file_content
except (FileNotFoundError, PermissionError, OSError) as e:
raise ValueError(
f"Error: Unable to read policies file ({policies_file}): {e}"
f"Error: Unable to read guardrails file ({guardrails_file}): {e}"
) from e
except Exception as e:
raise ValueError(f"Invalid policy content in {policies_file}: {e}") from e
raise ValueError(f"Invalid policy content in {guardrails_file}: {e}") from e
def __repr__(self) -> str:
return f"GatewayConfig(policies={repr(self.policies)})"
return f"GatewayConfig(guardrails={repr(self.guardrails)})"
class GatewayConfigManager:
+5
View File
@@ -3,9 +3,14 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from config_manager import GatewayConfig
@dataclass(frozen=True)
class RequestContextData:
"""Request context data class."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
invariant_authorization: Optional[str] = None
config: Optional[GatewayConfig] = None
+24
View File
@@ -5,9 +5,33 @@ from typing import Any, Dict, List
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
from invariant_sdk.types.annotations import AnnotationCreate
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def create_annotations_from_guardrails_errors(
guardrails_errors: List[dict],
) -> List[AnnotationCreate]:
"""Create Explorer annotations from the guardrails errors."""
annotations = []
for error in guardrails_errors:
content = error.get("args")[0]
address = None
for r in error.get("ranges", []):
# Choose the longest path as the address
if address is None or len(r) > len(address):
address = r
annotations.append(
AnnotationCreate(
content=content,
address=address,
extra_metadata={"source": "guardrails-error"},
)
)
return annotations
async def push_trace(
messages: List[List[Dict[str, Any]]],
dataset_name: str,
+41
View File
@@ -0,0 +1,41 @@
"""Utility functions for Guardrails execution."""
import os
from typing import Any, Dict, List
import httpx
DEFAULT_API_URL = "https://guardrail.invariantnet.com"
async def check_guardrails(
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
) -> Dict[str, Any]:
"""
Checks guardrails on the list of messages.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (str): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
Returns:
Dict: Response containing guardrail check results.
"""
client = httpx.AsyncClient()
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
result = await client.post(
f"{url}/api/v1/policy/check",
json={"messages": messages, "policy": guardrails},
headers={
"Authorization": invariant_authorization,
"Accept": "application/json",
},
)
print(f"Guardrail check response: {result.json()}")
return result.json()
except Exception as e:
print(f"Failed to verify guardrails: {e}")
return {"error": str(e)}
+49 -17
View File
@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -12,7 +12,8 @@ from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from integrations.explorer import push_trace
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardails import check_guardrails
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
@@ -68,6 +69,7 @@ async def openai_chat_completions_gateway(
request_json=request_json,
dataset_name=dataset_name,
invariant_authorization=invariant_authorization,
config=config,
)
if request_json.get("stream", False):
@@ -282,24 +284,30 @@ def update_existing_choice_with_delta(
async def push_to_explorer(
context: RequestContextData, merged_response: dict[str, Any]
context: RequestContextData,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
"""Pushes the merged response to the Invariant Explorer"""
# Only push the trace to explorer if the message is an end turn message
if (
# or if the guardrails check returned errors.
guardrails_execution_result = guardrails_execution_result or {}
guardrails_errors = guardrails_execution_result.get("errors", [])
if guardrails_errors or not (
merged_response.get("choices")
and merged_response["choices"][0].get("finish_reason")
not in FINISH_REASON_TO_PUSH_TRACE
):
return
# Combine the messages from the request body and the choices from the OpenAI response
messages = context.request_json.get("messages", [])
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=context.dataset_name,
messages=[messages],
invariant_authorization=context.invariant_authorization,
)
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
# Combine the messages from the request body and the choices from the OpenAI response
messages = context.request_json.get("messages", [])
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=context.dataset_name,
invariant_authorization=context.invariant_authorization,
messages=[messages],
annotations=[annotations],
)
async def handle_non_streaming_response(
@@ -318,13 +326,37 @@ async def handle_non_streaming_response(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
guardrails_execution_result = {}
response_string = json.dumps(json_response)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in json_response.get("choices", [])]
guardrails_execution_result = await check_guardrails(
messages=messages,
guardrails=context.config.guardrails,
invariant_authorization=context.invariant_authorization,
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "The request did not pass the guardrails",
"guadrails_check_result": guardrails_execution_result,
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(push_to_explorer(context, json_response))
asyncio.create_task(
push_to_explorer(context, json_response, guardrails_execution_result)
)
return Response(
content=json.dumps(json_response),
status_code=response.status_code,
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(response.headers),
)
+9 -9
View File
@@ -4,13 +4,13 @@ up() {
docker network create invariant-explorer-web
# Default values
POLICIES_FILE_PATH=""
GUARDRAILS_FILE_PATH=""
# Parse command-line arguments
while [[ "$#" -gt 0 ]]; do
case "$1" in
--policies-file=*)
POLICIES_FILE_PATH="${1#*=}"
--guardrails-file=*)
GUARDRAILS_FILE_PATH="${1#*=}"
;;
*)
echo "Unknown parameter: $1"
@@ -20,21 +20,21 @@ up() {
shift
done
if [[ -n "$POLICIES_FILE_PATH" ]]; then
if [[ -f "$POLICIES_FILE_PATH" ]]; then
POLICIES_FILE_PATH=$(realpath "$POLICIES_FILE_PATH")
if [[ -n "$GUARDRAILS_FILE_PATH" ]]; then
if [[ -f "$GUARDRAILS_FILE_PATH" ]]; then
GUARDRAILS_FILE_PATH=$(realpath "$GUARDRAILS_FILE_PATH")
else
echo "Error: Specified policies file does not exist: $POLICIES_FILE_PATH"
echo "Error: Specified guardrails file does not exist: $GUARDRAILS_FILE_PATH"
exit 1
fi
fi
# Start Docker Compose with the correct environment variable
POLICIES_FILE_PATH="$POLICIES_FILE_PATH" docker compose -f docker-compose.local.yml up -d
GUARDRAILS_FILE_PATH="$GUARDRAILS_FILE_PATH" docker compose -f docker-compose.local.yml up -d
echo "Gateway started at http://localhost:8005/api/v1/gateway/"
echo "See http://localhost:8005/api/v1/gateway/docs for API documentation"
echo "Using Policies File: ${POLICIES_FILE_PATH:-None}"
echo "Using Guardrails File: ${GUARDRAILS_FILE_PATH:-None}"
}
build() {