mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 06:56:49 +02:00
Add a preload guardrails API call.
This commit is contained in:
@@ -1,13 +1,107 @@
|
||||
"""Utility functions for Guardrails execution."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
import httpx
|
||||
from common.request_context_data import RequestContextData
|
||||
|
||||
# Fallback base URL of the Guardrails API. It is used whenever the
# GUADRAILS_API_URL environment variable (spelled exactly like that in the
# os.getenv lookups of this module) is not set.
DEFAULT_API_URL = "https://guardrail.invariantnet.com"
|
||||
|
||||
|
||||
# Timestamps (time.monotonic()) of the last API call started per guardrails string
_guardrails_cache: Dict[str, float] = {}
# Locks per guardrails string; an entry only lives while calls are in flight
_guardrails_locks: Dict[str, asyncio.Lock] = {}


def rate_limit(expiration_time: int = 3600):
    """
    Decorator to limit API calls to once per expiration_time seconds
    per unique guardrails string.

    The decorated coroutine function must take the guardrails string as its
    first positional argument. The wrapper always returns None; the return
    value of the wrapped coroutine is discarded.

    Calls for the same guardrails string are serialized with a per-string
    lock. A call that arrives while another one is in flight waits for it and
    is then skipped if that call completed. If the wrapped coroutine raises
    (or is cancelled), the recorded timestamp is rolled back so the next call
    retries instead of being suppressed for the whole window.

    Args:
        expiration_time (int): Time in seconds to cache the guardrails.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(guardrails: str, *args, **kwargs):
            # Get or create a per-guardrail lock
            guardrail_lock = _guardrails_locks.get(guardrails)
            if guardrail_lock is None:
                guardrail_lock = asyncio.Lock()
                _guardrails_locks[guardrails] = guardrail_lock

            try:
                async with guardrail_lock:
                    # Read the clock only after the lock is held so time spent
                    # waiting is accounted for. A monotonic clock is immune to
                    # wall-clock adjustments.
                    now = time.monotonic()
                    last_called = _guardrails_cache.get(guardrails)

                    # "is not None" rather than truthiness: 0.0 is a valid
                    # timestamp.
                    if (
                        last_called is not None
                        and now - last_called < expiration_time
                    ):
                        # Skipping API call: Guardrails '{guardrails}' already
                        # preloaded within expiration_time
                        return

                    # Update cache timestamp
                    _guardrails_cache[guardrails] = now

                    try:
                        await func(guardrails, *args, **kwargs)
                    except BaseException:
                        # The call did not complete (error or cancellation):
                        # forget the timestamp so a later call can retry.
                        _guardrails_cache.pop(guardrails, None)
                        raise
            finally:
                # Drop the lock entry on every exit path (including the
                # skipped one) so the dict does not grow with every distinct
                # guardrails string. Only remove our own, currently free lock.
                if (
                    _guardrails_locks.get(guardrails) is guardrail_lock
                    and not guardrail_lock.locked()
                ):
                    _guardrails_locks.pop(guardrails, None)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
@rate_limit(3600)  # Don't preload the same guardrails string more than once per hour
async def _preload(guardrails: str, invariant_authorization: str) -> None:
    """
    Calls the Guardrails API to preload the provided policy for faster checking later.

    This is best-effort: transport errors and HTTP error responses are
    reported on stdout and swallowed, so this coroutine does not raise for a
    failed preload.

    Args:
        guardrails (str): The guardrails to preload.
        invariant_authorization (str): Value of the
                                       invariant-authorization header.
    """
    url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                f"{url}/api/v1/policy/load",
                json={"policy": guardrails},
                headers={
                    "Authorization": invariant_authorization,
                    "Accept": "application/json",
                },
            )
            # Surface 4xx/5xx replies (e.g. rejected authorization) instead of
            # silently treating them as a successful preload.
            response.raise_for_status()
        except Exception as e:
            # Deliberately broad: preloading is an optimization and must never
            # break the request that triggered it.
            print(f"Failed to load guardrails: {e}")
|
||||
|
||||
|
||||
# Strong references to in-flight preload tasks. The event loop only keeps weak
# references to tasks, so a fire-and-forget task that nobody else references
# may be garbage-collected before it finishes.
_preload_tasks: set = set()


async def preload_guardrails(context: RequestContextData) -> None:
    """
    Preloads the guardrails for faster checking later.

    Schedules the preload as a background task and returns immediately
    without waiting for it. Does nothing when the context carries no config
    or no guardrails. Scheduling errors are reported on stdout and swallowed.

    Args:
        context: RequestContextData object.
    """
    if not context.config or not context.config.guardrails:
        return

    try:
        task = asyncio.create_task(
            _preload(context.config.guardrails, context.invariant_authorization)
        )
    except Exception as e:
        print(f"Error scheduling preload_guardrails task: {e}")
        return

    # Keep the task alive until it completes, then drop the reference. The
    # previous un-awaited asyncio.shield(task) neither kept the task alive nor
    # protected it from anything, because its result was discarded.
    _preload_tasks.add(task)
    task.add_done_callback(_preload_tasks.discard)
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str
|
||||
) -> Dict[str, Any]:
|
||||
@@ -23,19 +117,19 @@ async def check_guardrails(
|
||||
Returns:
|
||||
Dict: Response containing guardrail check results.
|
||||
"""
|
||||
client = httpx.AsyncClient()
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check",
|
||||
json={"messages": messages, "policy": guardrails},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
return {"error": str(e)}
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check",
|
||||
json={"messages": messages, "policy": guardrails},
|
||||
headers={
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
print(f"Guardrail check response: {result.json()}")
|
||||
return result.json()
|
||||
except Exception as e:
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -13,7 +13,7 @@ from common.constants import (
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardails import check_guardrails
|
||||
from integrations.guardails import check_guardrails, preload_guardrails
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
|
||||
@@ -71,6 +71,7 @@ async def openai_chat_completions_gateway(
|
||||
invariant_authorization=invariant_authorization,
|
||||
config=config,
|
||||
)
|
||||
asyncio.create_task(preload_guardrails(context))
|
||||
|
||||
if request_json.get("stream", False):
|
||||
return await stream_response(
|
||||
|
||||
Reference in New Issue
Block a user