diff --git a/gateway/converters/anthropic_to_invariant.py b/gateway/converters/anthropic_to_invariant.py index ffd5041..3448b64 100644 --- a/gateway/converters/anthropic_to_invariant.py +++ b/gateway/converters/anthropic_to_invariant.py @@ -15,7 +15,11 @@ def convert_anthropic_to_invariant_message_format( for message in messages: handler = role_mapping.get(message["role"]) if handler: - output.extend(handler(message)) + result = handler(message) + if isinstance(result, list): + output.extend(result) + else: + output.append(result) return output diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index ac8db7a..01004c6 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -100,12 +100,25 @@ def create_metadata( return metadata +def combine_request_and_response_messages( + context: RequestContextData, json_response: dict[str, Any] +): + """Combine the request and response messages""" + messages = [] + if "system" in context.request_json: + messages.append( + {"role": "system", "content": context.request_json.get("system")} + ) + messages.extend(context.request_json.get("messages", [])) + messages.append(json_response) + return messages + + async def get_guardrails_check_result( context: RequestContextData, json_response: dict[str, Any] ) -> dict[str, Any]: """Get the guardrails check result""" - messages = list(context.request_json.get("messages", [])) - messages.append(json_response) + messages = combine_request_and_response_messages(context, json_response) converted_messages = convert_anthropic_to_invariant_message_format(messages) # Block on the guardrails check @@ -129,8 +142,8 @@ async def push_to_explorer( ) # Combine the messages from the request body and Anthropic response - messages = list(context.request_json.get("messages", [])) - messages.append(merged_response) + messages = combine_request_and_response_messages(context, merged_response) + converted_messages = convert_anthropic_to_invariant_message_format(messages) _ = await push_trace( dataset_name=context.dataset_name, diff --git a/tests/unit_tests/converters/test_anthropic_to_invariant.py b/tests/unit_tests/converters/test_anthropic_to_invariant.py new file mode 100644 index 0000000..718270a --- /dev/null +++ b/tests/unit_tests/converters/test_anthropic_to_invariant.py @@ -0,0 +1,87 @@ +"""Test the conversion from anthropic to invariant.""" + +import os +import sys + +# Add root folder (parent) to sys.path +sys.path.append( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) +) + +from gateway.converters.anthropic_to_invariant import ( + convert_anthropic_to_invariant_message_format, +) + + +def test_convert_messages_without_tool_call(): + """Test the conversion without tool calls.""" + messages = [ + {"role": "user", "content": "What is the capital of France?"}, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The capital of France is Paris. It is also the largest city in France and one of the most populous cities in Europe. Paris is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral.", + } + ], + }, + ] + + converted_messages = convert_anthropic_to_invariant_message_format(messages) + assert converted_messages == [ + {"role": "user", "content": "What is the capital of France?"}, + { + "role": "assistant", + "content": "The capital of France is Paris. It is also the largest city in France and one of the most populous cities in Europe. Paris is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral.", + }, + ] + + +def test_convert_messages_with_tool_call(): + """Test the conversion with tool call""" + messages = [ + {"role": "system", "content": "This is the system message."}, + {"role": "user", "content": "What is the capital of France?"}, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "I'll help you find out the capital of France using the get_capital function.", + }, + { + "type": "tool_use", + "id": "toolu_013btUg7dbaEq7NbPGzw4K9u", + "name": "get_capital", + "input": {"country_name": "France"}, + }, + ], + }, + ] + + converted_messages = convert_anthropic_to_invariant_message_format(messages) + assert converted_messages == [ + {"role": "system", "content": "This is the system message."}, + {"role": "user", "content": "What is the capital of France?"}, + { + "role": "assistant", + "content": "I'll help you find out the capital of France using the get_capital function.", + }, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "toolu_013btUg7dbaEq7NbPGzw4K9u", + "type": "function", + "function": { + "name": "get_capital", + "arguments": {"country_name": "France"}, + }, + } + ], + }, + ]