diff --git a/gateway/converters/anthropic_to_invariant.py b/gateway/converters/anthropic_to_invariant.py index 4f6274e..ffd5041 100644 --- a/gateway/converters/anthropic_to_invariant.py +++ b/gateway/converters/anthropic_to_invariant.py @@ -1,5 +1,6 @@ """Converts the request and response formats from Anthropic to Invariant API format.""" + def convert_anthropic_to_invariant_message_format( messages: list[dict], keep_empty_tool_response: bool = False ) -> list[dict]: @@ -32,7 +33,7 @@ def handle_user_message(message, keep_empty_tool_response): { "role": "tool", "content": sub_message["content"], - "tool_id": sub_message["tool_use_id"], + "tool_call_id": sub_message["tool_use_id"], } ) elif keep_empty_tool_response and any(sub_message.values()): @@ -42,7 +43,7 @@ def handle_user_message(message, keep_empty_tool_response): "content": {"is_error": True} if sub_message["is_error"] else {}, - "tool_id": sub_message["tool_use_id"], + "tool_call_id": sub_message["tool_use_id"], } ) elif sub_message["type"] == "text": @@ -69,27 +70,24 @@ def handle_user_message(message, keep_empty_tool_response): def handle_assistant_message(message): """Handle the assistant message from the Anthropic API""" output = [] - if isinstance(message["content"], list): - for sub_message in message["content"]: - if sub_message["type"] == "text": - output.append({"role": "assistant", "content": sub_message.get("text")}) - elif sub_message["type"] == "tool_use": - output.append( - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "tool_id": sub_message.get("id"), - "type": "function", - "function": { - "name": sub_message.get("name"), - "arguments": sub_message.get("input"), - }, - } - ], - } - ) - else: - output.append({"role": "assistant", "content": message["content"]}) + for sub_message in message["content"]: + if sub_message["type"] == "text": + output.append({"role": "assistant", "content": sub_message.get("text")}) + elif sub_message["type"] == "tool_use": + output.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": sub_message.get("id"), + "type": "function", + "function": { + "name": sub_message.get("name"), + "arguments": sub_message.get("input"), + }, + } + ], + } + ) return output diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 40838e4..ac8db7a 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any +from typing import Any, Optional import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -12,13 +12,13 @@ from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, ) -from integrations.explorer import push_trace +from integrations.explorer import create_annotations_from_guardrails_errors, push_trace from converters.anthropic_to_invariant import ( convert_anthropic_to_invariant_message_format, ) from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData -from integrations.guardails import preload_guardrails +from integrations.guardails import check_guardrails, preload_guardrails gateway = APIRouter() @@ -27,8 +27,7 @@ FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: " END_REASONS = ["end_turn", "max_tokens", "stop_sequence"] MESSAGE_START = "message_start" -MESSGAE_DELTA = "message_delta" -MESSAGE_STOP = "message_stop" +MESSAGE_DELTA = "message_delta" CONTENT_BLOCK_START = "content_block_start" CONTENT_BLOCK_DELTA = "content_block_delta" CONTENT_BLOCK_STOP = "content_block_stop" @@ -101,21 +100,44 @@ def create_metadata( return metadata +async def get_guardrails_check_result( + context: RequestContextData, json_response: dict[str, Any] +) -> dict[str, Any]: + """Get the guardrails check result""" + messages = list(context.request_json.get("messages", [])) + messages.append(json_response) + converted_messages = convert_anthropic_to_invariant_message_format(messages) + + # Block on the guardrails check + guardrails_execution_result = await check_guardrails( + messages=converted_messages, + guardrails=context.config.guardrails, + invariant_authorization=context.invariant_authorization, + ) + return guardrails_execution_result + + async def push_to_explorer( context: RequestContextData, merged_response: dict[str, Any], + guardrails_execution_result: Optional[dict] = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" - # Combine the messages from the request body and Anthropic response - messages = context.request_json.get("messages", []) - messages += [merged_response] + guardrails_execution_result = guardrails_execution_result or {} + annotations = create_annotations_from_guardrails_errors( + guardrails_execution_result.get("errors", []) + ) + # Combine the messages from the request body and Anthropic response + messages = list(context.request_json.get("messages", [])) + messages.append(merged_response) converted_messages = convert_anthropic_to_invariant_message_format(messages) _ = await push_trace( dataset_name=context.dataset_name, messages=[converted_messages], invariant_authorization=context.invariant_authorization, metadata=[create_metadata(context, merged_response)], + annotations=[annotations] if annotations else None, ) @@ -136,15 +158,35 @@ async def handle_non_streaming_response( status_code=response.status_code, detail=json_response.get("error", "Unknown error from Anthropic"), ) - # Only push the trace to explorer if the last message is an end turn message - # Don't block on the response from explorer + + guardrails_execution_result = {} + response_string = json.dumps(json_response) + response_code = response.status_code + + if context.config and context.config.guardrails: + # Block on the guardrails check + guardrails_execution_result = await get_guardrails_check_result( + context, json_response + ) + if guardrails_execution_result.get("errors", []): + response_string = json.dumps( + { + "error": "[Invariant] The response did not pass the guardrails", + "details": guardrails_execution_result, + } + ) + response_code = 400 if context.dataset_name: - asyncio.create_task(push_to_explorer(context, json_response)) + # Push to Explorer - don't block on its response + asyncio.create_task( + push_to_explorer(context, json_response, guardrails_execution_result) + ) + updated_headers = response.headers.copy() updated_headers.pop("Content-Length", None) return Response( - content=json.dumps(json_response), - status_code=response.status_code, + content=response_string, + status_code=response_code, media_type="application/json", headers=dict(updated_headers), ) @@ -156,7 +198,7 @@ async def handle_streaming_response( anthropic_request: httpx.Request, ) -> StreamingResponse: """Handles streaming Anthropic responses""" - merged_response = [] + merged_response = {} response = await client.send(anthropic_request, stream=True) if response.status_code != 200: @@ -170,65 +212,96 @@ async def handle_streaming_response( async def event_generator() -> Any: async for chunk in response.aiter_bytes(): - chunk_decode = chunk.decode().strip() - if not chunk_decode: + decoded_chunk = chunk.decode().strip() + if not decoded_chunk: continue + process_chunk(decoded_chunk, merged_response) + if ( + "event: message_stop" in decoded_chunk + and context.config + and context.config.guardrails + ): + # Block on the guardrails check + guardrails_execution_result = await get_guardrails_check_result( + context, merged_response + ) + if guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "type": "error", + "error": { + "message": "[Invariant] The response did not pass the guardrails", + "details": guardrails_execution_result, + }, + } + ) + # Push annotated trace to the explorer - don't block on its response + if context.dataset_name: + asyncio.create_task( + push_to_explorer( + context, + merged_response, + guardrails_execution_result, + ) + ) + yield f"event: error\ndata: {error_chunk}\n\n".encode() + return yield chunk - process_chunk_text(chunk_decode, merged_response) if context.dataset_name: # Push to Explorer - don't block on the response - asyncio.create_task(push_to_explorer(context, merged_response[-1])) + asyncio.create_task(push_to_explorer(context, merged_response)) generator = event_generator() return StreamingResponse(generator, media_type="text/event-stream") -def process_chunk_text(chunk_decode, merged_response): +def process_chunk(chunk: str, merged_response: dict[str, Any]) -> None: """ Process the chunk of text and update the merged_response Example of chunk list can be find in: ../../resources/streaming_chunk_text/anthropic.txt """ - for text_block in chunk_decode.split("\n\n"): + for text_block in chunk.split("\n\n"): # might be empty block if len(text_block.split("\ndata:")) > 1: - text_data = text_block.split("\ndata:")[1] - text_json = json.loads(text_data) - update_merged_response(text_json, merged_response) + event_text = text_block.split("\ndata:")[1] + event = json.loads(event_text) + update_merged_response(event, merged_response) -def update_merged_response(text_json, merged_response): - """Update the formatted_invariant_response based on the text_json""" - if text_json.get("type") == MESSAGE_START: - message = text_json.get("message") - merged_response.append( - { - "id": message.get("id"), - "role": message.get("role"), - "content": "", - "model": message.get("model"), - "stop_reason": message.get("stop_reason"), - "stop_sequence": message.get("stop_sequence"), - } - ) - elif ( - text_json.get("type") == CONTENT_BLOCK_START - and text_json.get("content_block").get("type") == "tool_use" - ): - content_block = text_json.get("content_block") - merged_response.append( - { - "role": "tool", - "tool_id": content_block.get("id"), - "content": "", - } - ) - elif text_json.get("type") == CONTENT_BLOCK_DELTA: - if merged_response[-1]["role"] == "assistant": - merged_response[-1]["content"] += text_json.get("delta").get("text") - elif merged_response[-1]["role"] == "tool": - merged_response[-1]["content"] += text_json.get("delta").get("partial_json") - elif text_json.get("type") == MESSGAE_DELTA: - merged_response[-1]["stop_reason"] = text_json.get("delta").get("stop_reason") +def update_merged_response( + event: dict[str, Any], merged_response: dict[str, Any] +) -> None: + """ + Update the merged_response based on the event. + + Each stream uses the following event flow: + + 1. message_start: contains a Message object with empty content. + 2. A series of content blocks, each of which have a content_block_start, + one or more content_block_delta events, and a content_block_stop event. + Each content block will have an index that corresponds to its index in the + final Message content array. + 3. One or more message_delta events, indicating top-level changes to the final Message object. + A final message_stop event. + + """ + if event.get("type") == MESSAGE_START: + merged_response.update(**event.get("message")) + elif event.get("type") == CONTENT_BLOCK_START: + index = event.get("index") + if index >= len(merged_response.get("content")): + merged_response["content"].append(event.get("content_block")) + if event.get("content_block").get("type") == "tool_use": + merged_response.get("content")[-1]["input"] = "" + elif event.get("type") == CONTENT_BLOCK_DELTA: + index = event.get("index") + delta = event.get("delta") + if delta.get("type") == "text_delta": + merged_response.get("content")[index]["text"] += delta.get("text") + elif delta.get("type") == "input_json_delta": + merged_response.get("content")[index]["input"] += delta.get("partial_json") + elif event.get("type") == MESSAGE_DELTA: + merged_response["usage"].update(**event.get("usage")) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index e464ec3..897dbb8 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -352,7 +352,7 @@ async def push_to_explorer( ): annotations = create_annotations_from_guardrails_errors(guardrails_errors) # Combine the messages from the request body and the choices from the OpenAI response - messages = context.request_json.get("messages", []) + messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] _ = await push_trace( dataset_name=context.dataset_name,