Add some tests for the Anthropic conversion functions. Add support for system message in Anthropic. Rename some variables.

This commit is contained in:
Hemang
2025-03-20 23:51:08 +01:00
committed by Hemang Sarkar
parent 4a9930c30d
commit 3df9e73249
3 changed files with 109 additions and 5 deletions
+5 -1
View File
@@ -15,7 +15,11 @@ def convert_anthropic_to_invariant_message_format(
for message in messages:
handler = role_mapping.get(message["role"])
if handler:
output.extend(handler(message))
result = handler(message)
if isinstance(result, list):
output.extend(result)
else:
output.append(result)
return output
+17 -4
View File
@@ -100,12 +100,25 @@ def create_metadata(
return metadata
def combine_request_and_response_messages(
context: RequestContextData, json_response: dict[str, Any]
):
"""Combine the request and response messages"""
messages = []
if "system" in context.request_json:
messages.append(
{"role": "system", "content": context.request_json.get("system")}
)
messages.extend(context.request_json.get("messages", []))
messages.append(json_response)
return messages
async def get_guardrails_check_result(
context: RequestContextData, json_response: dict[str, Any]
) -> dict[str, Any]:
"""Get the guardrails check result"""
messages = list(context.request_json.get("messages", []))
messages.append(json_response)
messages = combine_request_and_response_messages(context, json_response)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
# Block on the guardrails check
@@ -129,8 +142,8 @@ async def push_to_explorer(
)
# Combine the messages from the request body and Anthropic response
messages = list(context.request_json.get("messages", []))
messages.append(merged_response)
messages = combine_request_and_response_messages(context, merged_response)
converted_messages = convert_anthropic_to_invariant_message_format(messages)
_ = await push_trace(
dataset_name=context.dataset_name,
@@ -0,0 +1,87 @@
"""Test the conversion from anthropic to invariant."""
import os
import sys
# Add root folder (parent) to sys.path
sys.path.append(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
)
from gateway.converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format,
)
def test_convert_messages_without_tool_call():
"""Test the conversion without tool calls."""
messages = [
{"role": "user", "content": "What is the capital of France?"},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "The capital of France is Paris. It is also the largest city in France and one of the most populous cities in Europe. Paris is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral.",
}
],
},
]
converted_messages = convert_anthropic_to_invariant_message_format(messages)
assert converted_messages == [
{"role": "user", "content": "What is the capital of France?"},
{
"role": "assistant",
"content": "The capital of France is Paris. It is also the largest city in France and one of the most populous cities in Europe. Paris is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral.",
},
]
def test_convert_messages_with_tool_call():
"""Test the conversion with tool call"""
messages = [
{"role": "system", "content": "This is the system message."},
{"role": "user", "content": "What is the capital of France?"},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "I'll help you find out the capital of France using the get_capital function.",
},
{
"type": "tool_use",
"id": "toolu_013btUg7dbaEq7NbPGzw4K9u",
"name": "get_capital",
"input": {"country_name": "France"},
},
],
},
]
converted_messages = convert_anthropic_to_invariant_message_format(messages)
assert converted_messages == [
{"role": "system", "content": "This is the system message."},
{"role": "user", "content": "What is the capital of France?"},
{
"role": "assistant",
"content": "I'll help you find out the capital of France using the get_capital function.",
},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "toolu_013btUg7dbaEq7NbPGzw4K9u",
"type": "function",
"function": {
"name": "get_capital",
"arguments": {"country_name": "France"},
},
}
],
},
]