mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-20 22:24:46 +02:00
Add guardrails checks for Anthropic.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Converts the request and response formats from Anthropic to Invariant API format."""
|
||||
|
||||
|
||||
def convert_anthropic_to_invariant_message_format(
|
||||
messages: list[dict], keep_empty_tool_response: bool = False
|
||||
) -> list[dict]:
|
||||
@@ -32,7 +33,7 @@ def handle_user_message(message, keep_empty_tool_response):
|
||||
{
|
||||
"role": "tool",
|
||||
"content": sub_message["content"],
|
||||
"tool_id": sub_message["tool_use_id"],
|
||||
"tool_call_id": sub_message["tool_use_id"],
|
||||
}
|
||||
)
|
||||
elif keep_empty_tool_response and any(sub_message.values()):
|
||||
@@ -42,7 +43,7 @@ def handle_user_message(message, keep_empty_tool_response):
|
||||
"content": {"is_error": True}
|
||||
if sub_message["is_error"]
|
||||
else {},
|
||||
"tool_id": sub_message["tool_use_id"],
|
||||
"tool_call_id": sub_message["tool_use_id"],
|
||||
}
|
||||
)
|
||||
elif sub_message["type"] == "text":
|
||||
@@ -69,27 +70,24 @@ def handle_user_message(message, keep_empty_tool_response):
|
||||
def handle_assistant_message(message):
    """Handle the assistant message from the Anthropic API.

    Converts one Anthropic assistant message into a list of Invariant-format
    messages.

    Args:
        message: The assistant message dict. Only ``message["content"]`` is
            read; it is either a list of content blocks or some other value
            (for example a plain string).

    Returns:
        A list of dicts. Non-list content yields a single assistant message
        carrying that content unchanged. For list content, every ``"text"``
        block becomes an assistant message with that text, and every
        ``"tool_use"`` block becomes an assistant message with ``content``
        set to ``None`` and a single entry in ``"tool_calls"``. Blocks of any
        other type are skipped.
    """
    content = message["content"]
    if not isinstance(content, list):
        # Non-block content must not be iterated: looping over a string would
        # index each character with ["type"] and raise a TypeError.
        return [{"role": "assistant", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "text":
            output.append({"role": "assistant", "content": sub_message.get("text")})
        elif block_type == "tool_use":
            output.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": sub_message.get("id"),
                            "type": "function",
                            "function": {
                                "name": sub_message.get("name"),
                                "arguments": sub_message.get("input"),
                            },
                        }
                    ],
                }
            )
    return output
|
||||
|
||||
+129
-56
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
@@ -12,13 +12,13 @@ from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from integrations.explorer import push_trace
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from converters.anthropic_to_invariant import (
|
||||
convert_anthropic_to_invariant_message_format,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from integrations.guardails import preload_guardrails
|
||||
from integrations.guardails import check_guardrails, preload_guardrails
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -27,8 +27,7 @@ FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
|
||||
END_REASONS = ["end_turn", "max_tokens", "stop_sequence"]
|
||||
|
||||
MESSAGE_START = "message_start"
|
||||
MESSGAE_DELTA = "message_delta"
|
||||
MESSAGE_STOP = "message_stop"
|
||||
MESSAGE_DELTA = "message_delta"
|
||||
CONTENT_BLOCK_START = "content_block_start"
|
||||
CONTENT_BLOCK_DELTA = "content_block_delta"
|
||||
CONTENT_BLOCK_STOP = "content_block_stop"
|
||||
@@ -101,21 +100,44 @@ def create_metadata(
|
||||
return metadata
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
    context: RequestContextData, json_response: dict[str, Any]
) -> dict[str, Any]:
    """Run the configured guardrails over the request messages plus a response.

    Args:
        context: Request context; its ``request_json``, ``config.guardrails``
            and ``invariant_authorization`` are read.
        json_response: The Anthropic response message to check. It is placed
            after the request's messages.

    Returns:
        Whatever ``check_guardrails`` returns for the converted conversation.
    """
    # Build a new list so the request payload held by the context is untouched.
    conversation = [*context.request_json.get("messages", []), json_response]
    invariant_messages = convert_anthropic_to_invariant_message_format(conversation)

    # Awaited here on purpose: callers wait for the guardrails verdict.
    return await check_guardrails(
        messages=invariant_messages,
        guardrails=context.config.guardrails,
        invariant_authorization=context.invariant_authorization,
    )
|
||||
|
||||
|
||||
async def push_to_explorer(
    context: RequestContextData,
    merged_response: dict[str, Any],
    guardrails_execution_result: Optional[dict] = None,
) -> None:
    """Pushes the full trace to the Invariant Explorer.

    Args:
        context: Request context; its ``request_json``, ``dataset_name`` and
            ``invariant_authorization`` are read.
        merged_response: The Anthropic response message, appended after the
            request's messages to form the trace.
        guardrails_execution_result: Optional guardrails result. Its
            ``"errors"`` list (empty when absent) is turned into annotations.
    """
    guardrails_execution_result = guardrails_execution_result or {}
    annotations = create_annotations_from_guardrails_errors(
        guardrails_execution_result.get("errors", [])
    )

    # Combine the messages from the request body and Anthropic response.
    # Work on a copy: extending the list owned by context.request_json in
    # place would alter the request payload and duplicate the response in
    # the trace.
    messages = list(context.request_json.get("messages", []))
    messages.append(merged_response)
    converted_messages = convert_anthropic_to_invariant_message_format(messages)

    _ = await push_trace(
        dataset_name=context.dataset_name,
        messages=[converted_messages],
        invariant_authorization=context.invariant_authorization,
        metadata=[create_metadata(context, merged_response)],
        # Only send annotations when there is at least one.
        annotations=[annotations] if annotations else None,
    )
|
||||
|
||||
|
||||
@@ -136,15 +158,35 @@ async def handle_non_streaming_response(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
# Don't block on the response from explorer
|
||||
|
||||
guardrails_execution_result = {}
|
||||
response_string = json.dumps(json_response)
|
||||
response_code = response.status_code
|
||||
|
||||
if context.config and context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, json_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
}
|
||||
)
|
||||
response_code = 400
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(push_to_explorer(context, json_response))
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(context, json_response, guardrails_execution_result)
|
||||
)
|
||||
|
||||
updated_headers = response.headers.copy()
|
||||
updated_headers.pop("Content-Length", None)
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
status_code=response.status_code,
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
headers=dict(updated_headers),
|
||||
)
|
||||
@@ -156,7 +198,7 @@ async def handle_streaming_response(
|
||||
anthropic_request: httpx.Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
merged_response = []
|
||||
merged_response = {}
|
||||
|
||||
response = await client.send(anthropic_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
@@ -170,65 +212,96 @@ async def handle_streaming_response(
|
||||
|
||||
async def event_generator() -> Any:
|
||||
async for chunk in response.aiter_bytes():
|
||||
chunk_decode = chunk.decode().strip()
|
||||
if not chunk_decode:
|
||||
decoded_chunk = chunk.decode().strip()
|
||||
if not decoded_chunk:
|
||||
continue
|
||||
process_chunk(decoded_chunk, merged_response)
|
||||
if (
|
||||
"event: message_stop" in decoded_chunk
|
||||
and context.config
|
||||
and context.config.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, merged_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": {
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
yield f"event: error\ndata: {error_chunk}\n\n".encode()
|
||||
return
|
||||
yield chunk
|
||||
|
||||
process_chunk_text(chunk_decode, merged_response)
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on the response
|
||||
asyncio.create_task(push_to_explorer(context, merged_response[-1]))
|
||||
asyncio.create_task(push_to_explorer(context, merged_response))
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
return StreamingResponse(generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None:
    """
    Process the chunk of text and update the merged_response.

    The chunk is split on blank lines into blocks. For every block that
    contains a ``data:`` line, the text following that marker is parsed as
    JSON and passed to ``update_merged_response`` exactly once.

    Example of chunk list can be found in:
    ../../resources/streaming_chunk_text/anthropic.txt

    Args:
        chunk: Decoded text of one streamed chunk.
        merged_response: Accumulated response, updated in place.
    """
    for text_block in chunk.split("\n\n"):
        parts = text_block.split("\ndata:")
        # Blocks without a data line (e.g. empty ones) carry no event.
        if len(parts) > 1:
            event = json.loads(parts[1])
            update_merged_response(event, merged_response)
|
||||
|
||||
|
||||
def update_merged_response(
    event: dict[str, Any], merged_response: dict[str, Any]
) -> None:
    """
    Update the merged_response based on the event.

    Each stream uses the following event flow:

    1. message_start: contains a Message object with empty content.
    2. A series of content blocks, each of which have a content_block_start,
       one or more content_block_delta events, and a content_block_stop event.
       Each content block will have an index that corresponds to its index in the
       final Message content array.
    3. One or more message_delta events, indicating top-level changes to the
       final Message object.
    4. A final message_stop event.

    Args:
        event: One parsed stream event.
        merged_response: The message being assembled; updated in place.
            Content-block events expect a message_start event to have
            populated ``merged_response["content"]`` first.
    """
    event_type = event.get("type")

    if event_type == MESSAGE_START:
        merged_response.update(**event.get("message"))

    elif event_type == CONTENT_BLOCK_START:
        index = event.get("index")
        content_block = event.get("content_block")
        content = merged_response.get("content")
        if index >= len(content):
            content.append(content_block)
            if content_block.get("type") == "tool_use":
                # The tool input is accumulated from partial_json string
                # fragments, so start from an empty string.
                # NOTE(review): the accumulated input stays a JSON string and
                # is never parsed back into an object - confirm consumers
                # expect that.
                content[-1]["input"] = ""

    elif event_type == CONTENT_BLOCK_DELTA:
        index = event.get("index")
        delta = event.get("delta")
        delta_type = delta.get("type")
        if delta_type == "text_delta":
            merged_response.get("content")[index]["text"] += delta.get("text") or ""
        elif delta_type == "input_json_delta":
            merged_response.get("content")[index]["input"] += (
                delta.get("partial_json") or ""
            )

    elif event_type == MESSAGE_DELTA:
        # The delta holds top-level message fields such as stop_reason; merge
        # them so the assembled message does not keep the placeholder values
        # received in message_start.
        merged_response.update(event.get("delta") or {})
        # Tolerate a missing usage section on either side.
        merged_response.setdefault("usage", {}).update(event.get("usage") or {})
|
||||
|
||||
@@ -352,7 +352,7 @@ async def push_to_explorer(
|
||||
):
|
||||
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = context.request_json.get("messages", [])
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
_ = await push_trace(
|
||||
dataset_name=context.dataset_name,
|
||||
|
||||
Reference in New Issue
Block a user