diff --git a/README.md b/README.md index 64b2a70..f50bf8a 100644 --- a/README.md +++ b/README.md @@ -395,4 +395,4 @@ To run a subset of the integration tests, execute: ```bash bash run.sh integration-tests open_ai/test_chat_with_tool_call.py -``` +``` \ No newline at end of file diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index add8839..00336e8 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -5,6 +5,7 @@ import os import time from functools import wraps from typing import Any +from datetime import datetime import httpx from fastapi import HTTPException @@ -16,12 +17,21 @@ from gateway.common.authorization import ( ) from gateway.common.guardrails import Guardrail +import uuid + # Timestamps of last API calls per guardrails string _guardrails_cache = {} # Locks per guardrails string _guardrails_locks = {} +# Temporary session ID generation +def generate_session_id(): + return str(uuid.uuid4()) + +session_id = generate_session_id() + + def rate_limit(expiration_time: int = 3600): """ Decorator to limit API calls to once per expiration_time seconds @@ -136,6 +146,7 @@ async def check_guardrails( """ async with httpx.AsyncClient() as client: url = os.getenv("GUARDRAILS_API_URL", DEFAULT_API_URL).rstrip("/") + try: result = await client.post( f"{url}/api/v1/policy/check/batch", @@ -143,10 +154,12 @@ async def check_guardrails( "messages": messages, "policies": [g.content for g in guardrails], "parameters": context.guardrails_parameters or {}, + "dataset_name": context.dataset_name, }, headers={ "Authorization": context.get_guardrailing_authorization(), "Accept": CONTENT_TYPE_JSON, + "X-Session-Id": session_id, }, timeout=5, ) diff --git a/uv.lock b/uv.lock index 2b1418d..795a894 100644 --- a/uv.lock +++ b/uv.lock @@ -249,7 +249,7 @@ wheels = [ [[package]] name = "invariant-gateway" -version = "0.0.5.2" +version = "0.0.8" source = { editable = "." } dependencies = [ { name = "fastapi" },