From e773cc9f2d7b153641afa88043a5de9ee69d22df Mon Sep 17 00:00:00 2001 From: Hemang Date: Thu, 13 Mar 2025 17:18:57 +0100 Subject: [PATCH] Add Guardrails verification logic for openai route. --- docker-compose.local.yml | 6 +-- gateway/common/config_manager.py | 28 +++++------ gateway/common/request_context_data.py | 5 ++ gateway/integrations/explorer.py | 24 ++++++++++ gateway/integrations/guardails.py | 41 ++++++++++++++++ gateway/routes/open_ai.py | 66 +++++++++++++++++++------- run.sh | 18 +++---- 7 files changed, 145 insertions(+), 43 deletions(-) create mode 100644 gateway/integrations/guardails.py diff --git a/docker-compose.local.yml b/docker-compose.local.yml index 4613534..069a926 100644 --- a/docker-compose.local.yml +++ b/docker-compose.local.yml @@ -9,14 +9,14 @@ services: - .env environment: - DEV_MODE=true - - POLICIES_FILE_PATH=${POLICIES_FILE_PATH:+/srv/resources/policies.py} + - GUARDRAILS_FILE_PATH=${GUARDRAILS_FILE_PATH:+/srv/resources/guardrails.py} volumes: - type: bind source: ./gateway target: /srv/gateway - type: bind - source: ${POLICIES_FILE_PATH:-/dev/null} - target: /srv/resources/policies.py + source: ${GUARDRAILS_FILE_PATH:-/dev/null} + target: /srv/resources/guardrails.py networks: - invariant-explorer-web ports: diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 7bdb3e2..c0438d4 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -10,35 +10,35 @@ class GatewayConfig: """Common configurations for the Gateway Server.""" def __init__(self): - self.policies = self._load_policies() + self.guardrails = self._load_guardrails() - def _load_policies(self) -> str: + def _load_guardrails(self) -> str: """ - Loads and validates policies from the file specified in POLICIES_FILE_PATH. - Returns the policy file content as a string if valid; otherwise, raises an error. + Loads and validates guardrails from the file specified in GUARDRAILS_FILE_PATH. + Returns the guardrails file content as a string if valid; otherwise, raises an error. """ - policies_file = os.getenv("POLICIES_FILE_PATH", "") + guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "") - if not policies_file: - print("[warning: POLICIES_FILE_PATH is not set. Using empty policies]") + if not guardrails_file: + print("[warning: GUARDRAILS_FILE_PATH is not set. Using empty guardrails]") return "" try: - with open(policies_file, "r", encoding="utf-8") as f: - policy_file_content = f.read() - _ = Policy.from_string(policy_file_content) - return policy_file_content + with open(guardrails_file, "r", encoding="utf-8") as f: + guardrails_file_content = f.read() + _ = Policy.from_string(guardrails_file_content) + return guardrails_file_content except (FileNotFoundError, PermissionError, OSError) as e: raise ValueError( - f"Error: Unable to read policies file ({policies_file}): {e}" + f"Error: Unable to read guardrails file ({guardrails_file}): {e}" ) from e except Exception as e: - raise ValueError(f"Invalid policy content in {policies_file}: {e}") from e + raise ValueError(f"Invalid policy content in {guardrails_file}: {e}") from e def __repr__(self) -> str: - return f"GatewayConfig(policies={repr(self.policies)})" + return f"GatewayConfig(guardrails={repr(self.guardrails)})" class GatewayConfigManager: diff --git a/gateway/common/request_context_data.py b/gateway/common/request_context_data.py index d59fb03..da967bf 100644 --- a/gateway/common/request_context_data.py +++ b/gateway/common/request_context_data.py @@ -3,9 +3,14 @@ from dataclasses import dataclass from typing import Any, Dict, Optional +from config_manager import GatewayConfig + + @dataclass(frozen=True) class RequestContextData: """Request context data class.""" + request_json: Dict[str, Any] dataset_name: Optional[str] = None invariant_authorization: Optional[str] = None + config: Optional[GatewayConfig] = None diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 1d6a18f..b557008 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -5,9 +5,33 @@ from typing import Any, Dict, List from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse +from invariant_sdk.types.annotations import AnnotationCreate DEFAULT_API_URL = "https://explorer.invariantlabs.ai" + +def create_annotations_from_guardrails_errors( + guardrails_errors: List[dict], +) -> List[AnnotationCreate]: + """Create Explorer annotations from the guardrails errors.""" + annotations = [] + for error in guardrails_errors: + content = error.get("args")[0] + address = None + for r in error.get("ranges", []): + # Choose the longest path as the address + if address is None or len(r) > len(address): + address = r + annotations.append( + AnnotationCreate( + content=content, + address=address, + extra_metadata={"source": "guardrails-error"}, + ) + ) + return annotations + + async def push_trace( messages: List[List[Dict[str, Any]]], dataset_name: str, diff --git a/gateway/integrations/guardails.py b/gateway/integrations/guardails.py new file mode 100644 index 0000000..a793f0d --- /dev/null +++ b/gateway/integrations/guardails.py @@ -0,0 +1,41 @@ +"""Utility functions for Guardrails execution.""" + +import os +from typing import Any, Dict, List + +import httpx + +DEFAULT_API_URL = "https://guardrail.invariantnet.com" + + +async def check_guardrails( + messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str +) -> Dict[str, Any]: + """ + Checks guardrails on the list of messages. + + Args: + messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. + guardrails (str): The guardrails to check against. + invariant_authorization (str): Value of the + invariant-authorization header. + + Returns: + Dict: Response containing guardrail check results. + """ + client = httpx.AsyncClient() + url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + try: + result = await client.post( + f"{url}/api/v1/policy/check", + json={"messages": messages, "policy": guardrails}, + headers={ + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + ) + print(f"Guardrail check response: {result.json()}") + return result.json() + except Exception as e: + print(f"Failed to verify guardrails: {e}") + return {"error": str(e)} diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 0528cae..c21474a 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any +from typing import Any, Optional import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -12,7 +12,8 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from integrations.explorer import push_trace +from integrations.explorer import create_annotations_from_guardrails_errors, push_trace +from integrations.guardails import check_guardrails from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData @@ -68,6 +69,7 @@ async def openai_chat_completions_gateway( request_json=request_json, dataset_name=dataset_name, invariant_authorization=invariant_authorization, + config=config, ) if request_json.get("stream", False): @@ -282,24 +284,30 @@ def update_existing_choice_with_delta( async def push_to_explorer( - context: RequestContextData, merged_response: dict[str, Any] + context: RequestContextData, + merged_response: dict[str, Any], + guardrails_execution_result: Optional[dict] = None, ) -> None: - """Pushes the full trace to the Invariant Explorer""" + """Pushes the merged response to the Invariant Explorer""" # Only push the trace to explorer if the message is an end turn message - if ( + # or if the guardrails check returned errors. + guardrails_execution_result = guardrails_execution_result or {} + guardrails_errors = guardrails_execution_result.get("errors", []) + if guardrails_errors or not ( merged_response.get("choices") and merged_response["choices"][0].get("finish_reason") not in FINISH_REASON_TO_PUSH_TRACE ): - return - # Combine the messages from the request body and the choices from the OpenAI response - messages = context.request_json.get("messages", []) - messages += [choice["message"] for choice in merged_response.get("choices", [])] - _ = await push_trace( - dataset_name=context.dataset_name, - messages=[messages], - invariant_authorization=context.invariant_authorization, - ) + annotations = create_annotations_from_guardrails_errors(guardrails_errors) + # Combine the messages from the request body and the choices from the OpenAI response + messages = context.request_json.get("messages", []) + messages += [choice["message"] for choice in merged_response.get("choices", [])] + _ = await push_trace( + dataset_name=context.dataset_name, + invariant_authorization=context.invariant_authorization, + messages=[messages], + annotations=[annotations], + ) async def handle_non_streaming_response( @@ -318,13 +326,37 @@ async def handle_non_streaming_response( status_code=response.status_code, detail=json_response.get("error", "Unknown error from OpenAI API"), ) + + guardrails_execution_result = {} + response_string = json.dumps(json_response) + response_code = response.status_code + + if context.config and context.config.guardrails: + # Block on the guardrails check + messages = list(context.request_json.get("messages", [])) + messages += [choice["message"] for choice in json_response.get("choices", [])] + guardrails_execution_result = await check_guardrails( + messages=messages, + guardrails=context.config.guardrails, + invariant_authorization=context.invariant_authorization, + ) + if guardrails_execution_result.get("errors", []): + response_string = json.dumps( + { + "error": "The request did not pass the guardrails", + "guadrails_check_result": guardrails_execution_result, + } + ) + response_code = 400 if context.dataset_name: # Push to Explorer - don't block on its response - asyncio.create_task(push_to_explorer(context, json_response)) + asyncio.create_task( + push_to_explorer(context, json_response, guardrails_execution_result) + ) return Response( - content=json.dumps(json_response), - status_code=response.status_code, + content=response_string, + status_code=response_code, media_type="application/json", headers=dict(response.headers), ) diff --git a/run.sh b/run.sh index 57a0064..f7bea4a 100755 --- a/run.sh +++ b/run.sh @@ -4,13 +4,13 @@ up() { docker network create invariant-explorer-web # Default values - POLICIES_FILE_PATH="" + GUARDRAILS_FILE_PATH="" # Parse command-line arguments while [[ "$#" -gt 0 ]]; do case "$1" in - --policies-file=*) - POLICIES_FILE_PATH="${1#*=}" + --guardrails-file=*) + GUARDRAILS_FILE_PATH="${1#*=}" ;; *) echo "Unknown parameter: $1" @@ -20,21 +20,21 @@ up() { shift done - if [[ -n "$POLICIES_FILE_PATH" ]]; then - if [[ -f "$POLICIES_FILE_PATH" ]]; then - POLICIES_FILE_PATH=$(realpath "$POLICIES_FILE_PATH") + if [[ -n "$GUARDRAILS_FILE_PATH" ]]; then + if [[ -f "$GUARDRAILS_FILE_PATH" ]]; then + GUARDRAILS_FILE_PATH=$(realpath "$GUARDRAILS_FILE_PATH") else - echo "Error: Specified policies file does not exist: $POLICIES_FILE_PATH" + echo "Error: Specified guardrails file does not exist: $GUARDRAILS_FILE_PATH" exit 1 fi fi # Start Docker Compose with the correct environment variable - POLICIES_FILE_PATH="$POLICIES_FILE_PATH" docker compose -f docker-compose.local.yml up -d + GUARDRAILS_FILE_PATH="$GUARDRAILS_FILE_PATH" docker compose -f docker-compose.local.yml up -d echo "Gateway started at http://localhost:8005/api/v1/gateway/" echo "See http://localhost:8005/api/v1/gateway/docs for API documentation" - echo "Using Policies File: ${POLICIES_FILE_PATH:-None}" + echo "Using Guardrails File: ${GUARDRAILS_FILE_PATH:-None}" } build() {