Add calls to execute logging guardrails before pushing to Explorer.

This commit is contained in:
Hemang
2025-04-01 14:41:18 +02:00
committed by Hemang Sarkar
parent 050ec1ba58
commit 750c83d3f8
5 changed files with 73 additions and 69 deletions
-23
View File
@@ -1,7 +1,6 @@
"""Utility functions for Guardrails execution."""
import asyncio
import json
import os
import time
from typing import Any, Dict, List
@@ -351,28 +350,6 @@ async def check_guardrails(
async with httpx.AsyncClient() as client:
url = os.getenv("GUADRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
try:
print(
"Hello there this is the request to guardrails: ",
json.dumps(
{
"messages": messages,
"policies": [g.content for g in guardrails],
},
indent=2,
),
flush=True,
)
print(
"Hello there this is the request to guardrails: ",
json.dumps(
{
"Authorization": invariant_authorization,
"Accept": "application/json",
},
indent=2,
),
flush=True,
)
result = await client.post(
f"{url}/api/v1/policy/check/batch",
json={
+31 -28
View File
@@ -120,7 +120,7 @@ def create_metadata(
def combine_request_and_response_messages(
context: RequestContext, json_response: dict[str, Any]
context: RequestContext, response_json: dict[str, Any]
):
"""Combine the request and response messages"""
messages = []
@@ -129,13 +129,13 @@ def combine_request_and_response_messages(
{"role": "system", "content": context.request_json.get("system")}
)
messages.extend(context.request_json.get("messages", []))
if len(json_response) > 0:
messages.append(json_response)
if len(response_json) > 0:
messages.append(response_json)
return messages
async def get_guardrails_check_result(
context: RequestContext, action: GuardrailAction, json_response: dict[str, Any]
context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
# Determine which guardrails to apply based on the action
@@ -147,7 +147,7 @@ async def get_guardrails_check_result(
if not guardrails:
return {}
messages = combine_request_and_response_messages(context, json_response)
messages = combine_request_and_response_messages(context, response_json)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
# Block on the guardrails check
@@ -170,10 +170,22 @@ async def push_to_explorer(
guardrails_execution_result.get("errors", [])
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
# Combine the messages from the request body and Anthropic response
messages = combine_request_and_response_messages(context, merged_response)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
_ = await push_trace(
dataset_name=context.dataset_name,
messages=[converted_messages],
@@ -200,7 +212,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
# response data
self.response: Optional[httpx.Response] = None
self.response_string: Optional[str] = None
self.json_response: Optional[dict[str, Any]] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrailing response (if any)
self.guardrails_execution_result = {}
@@ -209,7 +221,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.dataset_guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, action=GuardrailAction.BLOCK, json_response={}
self.context, action=GuardrailAction.BLOCK, response_json={}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -243,10 +255,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
)
async def request(self):
"""Make the request to the Anthropic API."""
self.response = await self.client.send(self.anthropic_request)
try:
json_response = self.response.json()
response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
@@ -255,11 +268,11 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
detail=response_json.get("error", "Unknown error from Anthropic"),
)
self.json_response = json_response
self.response_string = json.dumps(json_response)
self.response_json = response_json
self.response_string = json.dumps(response_json)
return self._make_response(
content=self.response_string,
@@ -284,7 +297,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
# ensure the response data is available
assert self.response is not None, "response is None"
assert self.json_response is not None, "json_response is None"
assert self.response_json is not None, "response_json is None"
assert self.response_string is not None, "response_string is None"
if self.context.dataset_guardrails:
@@ -292,12 +305,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.json_response,
)
print(
"Here is the guardrails_execution_result in on_end in InstrumentedAnthropicResponse: ",
guardrails_execution_result,
flush=True,
response_json=self.response_json,
)
if guardrails_execution_result.get("errors", []):
guardrail_response_string = json.dumps(
@@ -313,7 +321,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
guardrails_execution_result,
)
)
@@ -330,7 +338,7 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context, self.json_response, guardrails_execution_result
self.context, self.response_json, guardrails_execution_result
)
)
@@ -378,7 +386,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -440,12 +448,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
)
print(
"Here is the guardrails_execution_result in on_chunk in InstrumentedAnthropicStreamingResponse: ",
self.guardrails_execution_result,
flush=True,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
+12 -1
View File
@@ -290,7 +290,6 @@ async def stream_response(
async def event_generator():
async for chunk in response.instrumented_event_generator():
yield chunk
print("chunk", chunk)
return StreamingResponse(
event_generator(),
@@ -408,6 +407,18 @@ async def push_to_explorer(
guardrails_execution_result.get("errors", [])
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=response_json,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
converted_requests = convert_request(context.request_json)
converted_responses = convert_response(response_json)
+27 -14
View File
@@ -152,7 +152,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -203,7 +203,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.merged_response,
response_json=self.merged_response,
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
@@ -438,6 +438,19 @@ async def push_to_explorer(
not in FINISH_REASON_TO_PUSH_TRACE
):
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
# Combine the messages from the request body and the choices from the OpenAI response
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in merged_response.get("choices", [])]
@@ -453,7 +466,7 @@ async def push_to_explorer(
async def get_guardrails_check_result(
context: RequestContext,
action: GuardrailAction,
json_response: dict[str, Any] | None = None,
response_json: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Get the guardrails check result"""
# Determine which guardrails to apply based on the action
@@ -466,8 +479,8 @@ async def get_guardrails_check_result(
return {}
messages = list(context.request_json.get("messages", []))
if json_response is not None:
messages += [choice["message"] for choice in json_response.get("choices", [])]
if response_json is not None:
messages += [choice["message"] for choice in response_json.get("choices", [])]
# Block on the guardrails check
guardrails_execution_result = await check_guardrails(
@@ -499,7 +512,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
# request outputs
self.response: Optional[httpx.Response] = None
self.json_response: Optional[dict[str, Any]] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
@@ -545,7 +558,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.response = await self.client.send(self.open_ai_request)
try:
self.json_response = self.response.json()
self.response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
@@ -554,10 +567,10 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.json_response.get("error", "Unknown error from OpenAI API"),
detail=self.response_json.get("error", "Unknown error from OpenAI API"),
)
response_string = json.dumps(self.json_response)
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
return Response(
@@ -577,8 +590,8 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.response is not None
), "on_end called before 'self.response' was available"
assert (
self.json_response is not None
), "on_end called before 'self.json_response' was available"
self.response_json is not None
), "on_end called before 'self.response_json' was available"
# extract original response status code
response_code = self.response.status_code
@@ -589,7 +602,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
action=GuardrailAction.BLOCK,
json_response=self.json_response,
response_json=self.response_json,
)
if self.guardrails_execution_result.get("errors", []):
response_string = json.dumps(
@@ -605,7 +618,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
self.guardrails_execution_result,
)
)
@@ -624,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
asyncio.create_task(
push_to_explorer(
self.context,
self.json_response,
self.response_json,
# include any guardrailing errors if available
self.guardrails_execution_result,
)
@@ -195,14 +195,14 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
chat_response = client.models.generate_content(
model="gemini-2.0-flash",
contents="What is the capital of Spain?",
contents="What is the capital of Denmark?",
config={
"maxOutputTokens": 100,
},
)
# Verify the chat response
assert "MADRID" in chat_response.candidates[0].content.parts[0].text.upper()
assert "COPENHAGEN" in chat_response.candidates[0].content.parts[0].text.upper()
expected_assistant_message = chat_response.candidates[0].content.parts[0].text
# Wait for the trace to be saved
@@ -229,7 +229,7 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
assert trace["messages"] == [
{
"role": "user",
"content": [{"text": "What is the capital of Spain?", "type": "text"}],
"content": [{"text": "What is the capital of Denmark?", "type": "text"}],
},
{
"role": "assistant",