Add guardrails checks for Anthropic.

This commit is contained in:
Hemang
2025-03-20 23:14:43 +01:00
committed by Hemang Sarkar
parent 781c6224d9
commit 4a9930c30d
3 changed files with 153 additions and 82 deletions
+23 -25
View File
@@ -1,5 +1,6 @@
"""Converts the request and response formats from Anthropic to Invariant API format."""
def convert_anthropic_to_invariant_message_format(
messages: list[dict], keep_empty_tool_response: bool = False
) -> list[dict]:
@@ -32,7 +33,7 @@ def handle_user_message(message, keep_empty_tool_response):
{
"role": "tool",
"content": sub_message["content"],
"tool_id": sub_message["tool_use_id"],
"tool_call_id": sub_message["tool_use_id"],
}
)
elif keep_empty_tool_response and any(sub_message.values()):
@@ -42,7 +43,7 @@ def handle_user_message(message, keep_empty_tool_response):
"content": {"is_error": True}
if sub_message["is_error"]
else {},
"tool_id": sub_message["tool_use_id"],
"tool_call_id": sub_message["tool_use_id"],
}
)
elif sub_message["type"] == "text":
@@ -69,27 +70,24 @@ def handle_user_message(message, keep_empty_tool_response):
def handle_assistant_message(message):
    """Convert one Anthropic assistant message into Invariant-format messages.

    Args:
        message: Anthropic assistant message. ``message["content"]`` is either
            a list of content blocks (dicts with a ``"type"`` key) or any
            other value (for example a plain string), which is passed through
            unchanged as the content of a single assistant message.

    Returns:
        A list of ``{"role": "assistant", ...}`` dicts: one per ``text`` block
        and one per ``tool_use`` block (each carrying a single ``tool_calls``
        entry). Blocks of any other type are dropped.
    """
    content = message["content"]
    if not isinstance(content, list):
        # Non-list content cannot be walked block by block; forward it as is
        # instead of failing on `sub_message["type"]`.
        return [{"role": "assistant", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "text":
            output.append({"role": "assistant", "content": sub_message.get("text")})
        elif block_type == "tool_use":
            output.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            # The tool call is identified by "id" (not
                            # "tool_id"); the input dict is forwarded as the
                            # function arguments without re-encoding.
                            "id": sub_message.get("id"),
                            "type": "function",
                            "function": {
                                "name": sub_message.get("name"),
                                "arguments": sub_message.get("input"),
                            },
                        }
                    ],
                }
            )
    return output
+129 -56
View File
@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -12,13 +12,13 @@ from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from integrations.explorer import push_trace
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format,
)
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from integrations.guardails import preload_guardrails
from integrations.guardails import check_guardrails, preload_guardrails
gateway = APIRouter()
@@ -27,8 +27,7 @@ FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
END_REASONS = ["end_turn", "max_tokens", "stop_sequence"]
MESSAGE_START = "message_start"
MESSGAE_DELTA = "message_delta"
MESSAGE_STOP = "message_stop"
MESSAGE_DELTA = "message_delta"
CONTENT_BLOCK_START = "content_block_start"
CONTENT_BLOCK_DELTA = "content_block_delta"
CONTENT_BLOCK_STOP = "content_block_stop"
@@ -101,21 +100,44 @@ def create_metadata(
return metadata
async def get_guardrails_check_result(
    context: RequestContextData, json_response: dict[str, Any]
) -> dict[str, Any]:
    """Run the configured guardrails over the request messages plus the response.

    The request's messages are copied (the request body itself is left
    untouched), the Anthropic response is added as the last message, and the
    conversation is converted to the Invariant format before being checked.
    The call awaits the guardrails result and returns it as is.
    """
    conversation = [*context.request_json.get("messages", []), json_response]
    return await check_guardrails(
        messages=convert_anthropic_to_invariant_message_format(conversation),
        guardrails=context.config.guardrails,
        invariant_authorization=context.invariant_authorization,
    )
async def push_to_explorer(
    context: RequestContextData,
    merged_response: dict[str, Any],
    guardrails_execution_result: Optional[dict] = None,
) -> None:
    """Pushes the full trace to the Invariant Explorer.

    Args:
        context: Request context; supplies the request body, dataset name and
            Invariant authorization.
        merged_response: The Anthropic response message to append to the
            request's messages.
        guardrails_execution_result: Optional guardrails result; its
            ``"errors"`` entries (if any) are turned into trace annotations.
    """
    guardrails_execution_result = guardrails_execution_result or {}
    annotations = create_annotations_from_guardrails_errors(
        guardrails_execution_result.get("errors", [])
    )

    # Combine the messages from the request body and Anthropic response.
    # Work on a copy: appending to the list stored in context.request_json
    # would mutate the shared request body and add the response twice.
    messages = list(context.request_json.get("messages", []))
    messages.append(merged_response)
    converted_messages = convert_anthropic_to_invariant_message_format(messages)

    _ = await push_trace(
        dataset_name=context.dataset_name,
        messages=[converted_messages],
        invariant_authorization=context.invariant_authorization,
        metadata=[create_metadata(context, merged_response)],
        annotations=[annotations] if annotations else None,
    )
@@ -136,15 +158,35 @@ async def handle_non_streaming_response(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
# Don't block on the response from explorer
guardrails_execution_result = {}
response_string = json.dumps(json_response)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, json_response
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
}
)
response_code = 400
if context.dataset_name:
asyncio.create_task(push_to_explorer(context, json_response))
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, json_response, guardrails_execution_result)
)
updated_headers = response.headers.copy()
updated_headers.pop("Content-Length", None)
return Response(
content=json.dumps(json_response),
status_code=response.status_code,
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(updated_headers),
)
@@ -156,7 +198,7 @@ async def handle_streaming_response(
anthropic_request: httpx.Request,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
merged_response = []
merged_response = {}
response = await client.send(anthropic_request, stream=True)
if response.status_code != 200:
@@ -170,65 +212,96 @@ async def handle_streaming_response(
async def event_generator() -> Any:
async for chunk in response.aiter_bytes():
chunk_decode = chunk.decode().strip()
if not chunk_decode:
decoded_chunk = chunk.decode().strip()
if not decoded_chunk:
continue
process_chunk(decoded_chunk, merged_response)
if (
"event: message_stop" in decoded_chunk
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"type": "error",
"error": {
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
},
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"event: error\ndata: {error_chunk}\n\n".encode()
return
yield chunk
process_chunk_text(chunk_decode, merged_response)
if context.dataset_name:
# Push to Explorer - don't block on the response
asyncio.create_task(push_to_explorer(context, merged_response[-1]))
asyncio.create_task(push_to_explorer(context, merged_response))
generator = event_generator()
return StreamingResponse(generator, media_type="text/event-stream")
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
    """
    Process one decoded chunk of the stream and update the merged_response.

    The chunk is split into blocks on blank lines; for every block that has a
    ``data:`` field, the JSON payload is parsed and folded into
    merged_response via update_merged_response.

    Example of chunk list can be found in:
    ../../resources/streaming_chunk_text/anthropic.txt

    NOTE(review): json.loads raises if a payload is incomplete, so this
    presumably relies on every chunk holding whole events - confirm against
    the caller.
    """
    for text_block in chunk.split("\n\n"):
        parts = text_block.split("\ndata:")
        # Blocks without a data field (e.g. empty blocks) carry no event.
        if len(parts) > 1:
            update_merged_response(json.loads(parts[1]), merged_response)


# Backward-compatible alias for the previous name of this function.
process_chunk_text = process_chunk
def update_merged_response(
    event: dict[str, Any], merged_response: dict[str, Any]
) -> None:
    """
    Update the merged_response in place based on the event.

    Each stream uses the following event flow:
    1. message_start: contains a Message object with empty content.
    2. A series of content blocks, each of which have a content_block_start,
       one or more content_block_delta events, and a content_block_stop event.
       Each content block will have an index that corresponds to its index in
       the final Message content array.
    3. One or more message_delta events, indicating top-level changes to the
       final Message object.
    4. A final message_stop event.

    Args:
        event: One parsed streaming event (a dict with a "type" key).
        merged_response: The Message dict accumulated so far; mutated in place.
            Events of any other type leave it unchanged.
    """
    event_type = event.get("type")

    if event_type == MESSAGE_START:
        # Seed the merged message with the initial Message object.
        merged_response.update(event.get("message") or {})

    elif event_type == CONTENT_BLOCK_START:
        content = merged_response.setdefault("content", [])
        content_block = event.get("content_block")
        if event.get("index") >= len(content):
            content.append(content_block)
        if content_block.get("type") == "tool_use":
            # The tool input is accumulated as a raw JSON string from the
            # partial_json fragments; it is not parsed here.
            content[-1]["input"] = ""

    elif event_type == CONTENT_BLOCK_DELTA:
        index = event.get("index")
        delta = event.get("delta")
        delta_type = delta.get("type")
        if delta_type == "text_delta":
            merged_response.get("content")[index]["text"] += delta.get("text")
        elif delta_type == "input_json_delta":
            merged_response.get("content")[index]["input"] += delta.get(
                "partial_json"
            )

    elif event_type == MESSAGE_DELTA:
        # The delta carries the top-level changes to the Message (such as
        # stop_reason); without merging it they keep their message_start
        # values.
        merged_response.update(event.get("delta") or {})
        merged_response.setdefault("usage", {}).update(event.get("usage") or {})
+1 -1
View File
@@ -352,7 +352,7 @@ async def push_to_explorer(
):
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
# Combine the messages from the request body and the choices from the OpenAI response
messages = context.request_json.get("messages", [])
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=context.dataset_name,