mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-08 14:13:54 +02:00
Add calls to execute logging guardrails before pushing to explorer.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""Utility functions for Guardrails execution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
@@ -351,28 +350,6 @@ async def check_guardrails(
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
try:
|
||||
print(
|
||||
"Hello there this is the request to guardrails: ",
|
||||
json.dumps(
|
||||
{
|
||||
"messages": messages,
|
||||
"policies": [g.content for g in guardrails],
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
"Hello there this is the request to guardrails: ",
|
||||
json.dumps(
|
||||
{
|
||||
"Authorization": invariant_authorization,
|
||||
"Accept": "application/json",
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
result = await client.post(
|
||||
f"{url}/api/v1/policy/check/batch",
|
||||
json={
|
||||
|
||||
+31
-28
@@ -120,7 +120,7 @@ def create_metadata(
|
||||
|
||||
|
||||
def combine_request_and_response_messages(
|
||||
context: RequestContext, json_response: dict[str, Any]
|
||||
context: RequestContext, response_json: dict[str, Any]
|
||||
):
|
||||
"""Combine the request and response messages"""
|
||||
messages = []
|
||||
@@ -129,13 +129,13 @@ def combine_request_and_response_messages(
|
||||
{"role": "system", "content": context.request_json.get("system")}
|
||||
)
|
||||
messages.extend(context.request_json.get("messages", []))
|
||||
if len(json_response) > 0:
|
||||
messages.append(json_response)
|
||||
if len(response_json) > 0:
|
||||
messages.append(response_json)
|
||||
return messages
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContext, action: GuardrailAction, json_response: dict[str, Any]
|
||||
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
# Determine which guardrails to apply based on the action
|
||||
@@ -147,7 +147,7 @@ async def get_guardrails_check_result(
|
||||
if not guardrails:
|
||||
return {}
|
||||
|
||||
messages = combine_request_and_response_messages(context, json_response)
|
||||
messages = combine_request_and_response_messages(context, response_json)
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
|
||||
# Block on the guardrails check
|
||||
@@ -170,10 +170,22 @@ async def push_to_explorer(
|
||||
guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
# Combine the messages from the request body and Anthropic response
|
||||
messages = combine_request_and_response_messages(context, merged_response)
|
||||
|
||||
converted_messages = convert_anthropic_to_invariant_message_format(messages)
|
||||
|
||||
_ = await push_trace(
|
||||
dataset_name=context.dataset_name,
|
||||
messages=[converted_messages],
|
||||
@@ -200,7 +212,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
# response data
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_string: Optional[str] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
@@ -209,7 +221,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
if self.context.dataset_guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, action=GuardrailAction.BLOCK, json_response={}
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -243,10 +255,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def request(self):
|
||||
"""Make the request to the Anthropic API."""
|
||||
self.response = await self.client.send(self.anthropic_request)
|
||||
|
||||
try:
|
||||
json_response = self.response.json()
|
||||
response_json = self.response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
@@ -255,11 +268,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
detail=response_json.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
|
||||
self.json_response = json_response
|
||||
self.response_string = json.dumps(json_response)
|
||||
self.response_json = response_json
|
||||
self.response_string = json.dumps(response_json)
|
||||
|
||||
return self._make_response(
|
||||
content=self.response_string,
|
||||
@@ -284,7 +297,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
|
||||
# ensure the response data is available
|
||||
assert self.response is not None, "response is None"
|
||||
assert self.json_response is not None, "json_response is None"
|
||||
assert self.response_json is not None, "response_json is None"
|
||||
assert self.response_string is not None, "response_string is None"
|
||||
|
||||
if self.context.dataset_guardrails:
|
||||
@@ -292,12 +305,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.json_response,
|
||||
)
|
||||
print(
|
||||
"Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ",
|
||||
guardrails_execution_result,
|
||||
flush=True,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
guardrail_response_string = json.dumps(
|
||||
@@ -313,7 +321,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
@@ -330,7 +338,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context, self.json_response, guardrails_execution_result
|
||||
self.context, self.response_json, guardrails_execution_result
|
||||
)
|
||||
)
|
||||
|
||||
@@ -378,7 +386,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -440,12 +448,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
)
|
||||
print(
|
||||
"Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ",
|
||||
self.guardrails_execution_result,
|
||||
flush=True,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
|
||||
@@ -290,7 +290,6 @@ async def stream_response(
|
||||
async def event_generator():
|
||||
async for chunk in response.instrumented_event_generator():
|
||||
yield chunk
|
||||
print("chunk", chunk)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -408,6 +407,18 @@ async def push_to_explorer(
|
||||
guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=response_json,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
converted_requests = convert_request(context.request_json)
|
||||
converted_responses = convert_response(response_json)
|
||||
|
||||
|
||||
+27
-14
@@ -152,7 +152,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -203,7 +203,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.merged_response,
|
||||
response_json=self.merged_response,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
@@ -438,6 +438,19 @@ async def push_to_explorer(
|
||||
not in FINISH_REASON_TO_PUSH_TRACE
|
||||
):
|
||||
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
@@ -453,7 +466,7 @@ async def push_to_explorer(
|
||||
async def get_guardrails_check_result(
|
||||
context: RequestContext,
|
||||
action: GuardrailAction,
|
||||
json_response: dict[str, Any] | None = None,
|
||||
response_json: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get the guardrails check result"""
|
||||
# Determine which guardrails to apply based on the action
|
||||
@@ -466,8 +479,8 @@ async def get_guardrails_check_result(
|
||||
return {}
|
||||
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
if json_response is not None:
|
||||
messages += [choice["message"] for choice in json_response.get("choices", [])]
|
||||
if response_json is not None:
|
||||
messages += [choice["message"] for choice in response_json.get("choices", [])]
|
||||
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await check_guardrails(
|
||||
@@ -499,7 +512,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
|
||||
# request outputs
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.json_response: Optional[dict[str, Any]] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
|
||||
# guardrailing output (if any)
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
@@ -545,7 +558,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.response = await self.client.send(self.open_ai_request)
|
||||
|
||||
try:
|
||||
self.json_response = self.response.json()
|
||||
self.response_json = self.response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
@@ -554,10 +567,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
|
||||
detail=self.response_json.get("error", "Unknown error from OpenAI API"),
|
||||
)
|
||||
|
||||
response_string = json.dumps(self.json_response)
|
||||
response_string = json.dumps(self.response_json)
|
||||
response_code = self.response.status_code
|
||||
|
||||
return Response(
|
||||
@@ -577,8 +590,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.response is not None
|
||||
), "on_end called before 'self.response' was available"
|
||||
assert (
|
||||
self.json_response is not None
|
||||
), "on_end called before 'self.json_response' was available"
|
||||
self.response_json is not None
|
||||
), "on_end called before 'self.response_json' was available"
|
||||
|
||||
# extract original response status code
|
||||
response_code = self.response.status_code
|
||||
@@ -589,7 +602,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
action=GuardrailAction.BLOCK,
|
||||
json_response=self.json_response,
|
||||
response_json=self.response_json,
|
||||
)
|
||||
if self.guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
@@ -605,7 +618,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
@@ -624,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
self.context,
|
||||
self.json_response,
|
||||
self.response_json,
|
||||
# include any guardrailing errors if available
|
||||
self.guardrails_execution_result,
|
||||
)
|
||||
|
||||
@@ -195,14 +195,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
|
||||
chat_response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents="What is the capital of Spain?",
|
||||
contents="What is the capital of Denmark?",
|
||||
config={
|
||||
"maxOutputTokens": 100,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the chat response
|
||||
assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper()
|
||||
assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper()
|
||||
expected_assistant_message = chat_response.candidates[0].content.parts[0].text
|
||||
|
||||
# Wait for the trace to be saved
|
||||
@@ -229,7 +229,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "What is the capital of Spain?", "type": "text"}],
|
||||
"content": [{"text": "What is the capital of Denmark?", "type": "text"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
||||
Reference in New Issue
Block a user