diff --git a/gateway/integrations/guardails.py b/gateway/integrations/guardails.py index a793f0d..6944e1e 100644 --- a/gateway/integrations/guardails.py +++ b/gateway/integrations/guardails.py @@ -1,13 +1,107 @@ """Utility functions for Guardrails execution.""" +import asyncio import os +import time from typing import Any, Dict, List +from functools import wraps import httpx +from common.request_context_data import RequestContextData DEFAULT_API_URL = "https://guardrail.invariantnet.com" +# Timestamps of last API calls per guardrails string +_guardrails_cache = {} +# Locks per guardrails string +_guardrails_locks = {} + + +def rate_limit(expiration_time: int = 3600): + """ + Decorator to limit API calls to once per expiration_time seconds + per unique guardrails string. + + Args: + expiration_time (int): Time in seconds to cache the guardrails. + """ + + def decorator(func): + @wraps(func) + async def wrapper(guardrails: str, *args, **kwargs): + now = time.time() + + # Get or create a per-guardrail lock + if guardrails not in _guardrails_locks: + _guardrails_locks[guardrails] = asyncio.Lock() + guardrail_lock = _guardrails_locks[guardrails] + + async with guardrail_lock: + last_called = _guardrails_cache.get(guardrails) + + if last_called and (now - last_called < expiration_time): + # Skipping API call: Guardrails '{guardrails}' already + # preloaded within expiration_time + return + + # Update cache timestamp + _guardrails_cache[guardrails] = now + + try: + await func(guardrails, *args, **kwargs) + finally: + _guardrails_locks.pop(guardrails, None) + + return wrapper + + return decorator + + +@rate_limit(3600) # Don't preload the same guardrails string more than once per hour +async def _preload(guardrails: str, invariant_authorization: str) -> None: + """ + Calls the Guardrails API to preload the provided policy for faster checking later. + + Args: + guardrails (str): The guardrails to preload. + invariant_authorization (str): Value of the + invariant-authorization header. + """ + async with httpx.AsyncClient() as client: + url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + try: + await client.post( + f"{url}/api/v1/policy/load", + json={"policy": guardrails}, + headers={ + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + ) + except Exception as e: + print(f"Failed to load guardrails: {e}") + + +async def preload_guardrails(context: RequestContextData) -> None: + """ + Preloads the guardrails for faster checking later. + + Args: + context: RequestContextData object. + """ + if not context.config or not context.config.guardrails: + return + + try: + task = asyncio.create_task( + _preload(context.config.guardrails, context.invariant_authorization) + ) + asyncio.shield(task) + except Exception as e: + print(f"Error scheduling preload_guardrails task: {e}") + + async def check_guardrails( messages: List[Dict[str, Any]], guardrails: str, invariant_authorization: str ) -> Dict[str, Any]: @@ -23,19 +117,19 @@ async def check_guardrails( Returns: Dict: Response containing guardrail check results. """ - client = httpx.AsyncClient() - url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") - try: - result = await client.post( - f"{url}/api/v1/policy/check", - json={"messages": messages, "policy": guardrails}, - headers={ - "Authorization": invariant_authorization, - "Accept": "application/json", - }, - ) - print(f"Guardrail check response: {result.json()}") - return result.json() - except Exception as e: - print(f"Failed to verify guardrails: {e}") - return {"error": str(e)} + async with httpx.AsyncClient() as client: + url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + try: + result = await client.post( + f"{url}/api/v1/policy/check", + json={"messages": messages, "policy": guardrails}, + headers={ + "Authorization": invariant_authorization, + "Accept": "application/json", + }, + ) + print(f"Guardrail check response: {result.json()}") + return result.json() + except Exception as e: + print(f"Failed to verify guardrails: {e}") + return {"error": str(e)} diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index d972911..ce04c95 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -13,7 +13,7 @@ from common.constants import ( IGNORED_HEADERS, ) from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from integrations.guardails import check_guardrails +from integrations.guardails import check_guardrails, preload_guardrails from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData @@ -71,6 +71,7 @@ async def openai_chat_completions_gateway( invariant_authorization=invariant_authorization, config=config, ) + asyncio.create_task(preload_guardrails(context)) if request_json.get("stream", False): return await stream_response(