diff --git a/proxy/routes/anthropic.py b/proxy/routes/anthropic.py index 1f92305..115c8fb 100644 --- a/proxy/routes/anthropic.py +++ b/proxy/routes/anthropic.py @@ -136,13 +136,12 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from Anthropic"), ) # Only push the trace to explorer if the last message is an end turn message - if json_response.get("stop_reason") in END_REASONS: - await push_to_explorer( - dataset_name, - json_response, - request_body_json, - invariant_authorization, - ) + await push_to_explorer( + dataset_name, + json_response, + request_body_json, + invariant_authorization, + ) async def handle_streaming_response( client: httpx.AsyncClient, @@ -154,7 +153,6 @@ async def handle_streaming_response( formatted_invariant_response = [] response = await client.send(anthropic_request, stream=True) - if response.status_code != 200: error_content = await response.aread() try: @@ -169,7 +167,6 @@ async def handle_streaming_response( chunk_decode = chunk.decode().strip() if not chunk_decode: continue - yield chunk process_chunk_text( @@ -177,13 +174,12 @@ async def handle_streaming_response( formatted_invariant_response ) - if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS: - await push_to_explorer( - dataset_name, - formatted_invariant_response[-1], - json.loads(anthropic_request.content), - invariant_authorization, - ) + await push_to_explorer( + dataset_name, + formatted_invariant_response[-1], + json.loads(anthropic_request.content), + invariant_authorization, + ) generator = event_generator() @@ -249,11 +245,11 @@ def anthropic_to_invariant_messages( return output - def handle_user_message(message, keep_empty_tool_response): output = [] content = message["content"] if isinstance(content, list): + user_content = [] for sub_message in content: if sub_message["type"] == "tool_result": if sub_message["content"]: @@ -275,7 +271,24 @@ def handle_user_message(message, keep_empty_tool_response): } ) elif sub_message["type"] == "text": - output.append({"role": "user", "content": sub_message["text"]}) + user_content.append({ + "type":"text", + "text":sub_message["text"] + }) + elif sub_message["type"] == "image": + user_content.append({ + "type": "image_url", + "image_url": { + "url": "data:"+sub_message["source"]["media_type"]+";base64,"+sub_message["source"]["data"], + }, + }, + + ) + if user_content: + output.append({ + "role": "user", + "content": user_content + }) else: output.append({"role": "user", "content": content}) return output diff --git a/tests/anthropic/test_anthropic_with_tool_call.py b/tests/anthropic/test_anthropic_with_tool_call.py index 3f94bec..87b0881 100644 --- a/tests/anthropic/test_anthropic_with_tool_call.py +++ b/tests/anthropic/test_anthropic_with_tool_call.py @@ -5,7 +5,10 @@ import json import anthropic import pytest from httpx import Client +import base64 import sys +from pathlib import Path + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from util import * # needed for pytest fixtures @@ -45,11 +48,10 @@ class WeatherAgent: }, } - def get_response(self, user_query: str) -> Dict: + def get_response(self, messages: str) -> Dict: """ Get the response from the agent for a given user query for weather. """ - messages = [{"role": "user", "content": user_query}] response_list = [] while True: response = self.client.messages.create( @@ -80,8 +82,7 @@ class WeatherAgent: else: return response_list - def get_streaming_response(self, user_query: str) -> Dict: - messages = [{"role": "user", "content": user_query}] + def get_streaming_response(self, messages: str) -> Dict: response_list = [] def clean_quotes(text): @@ -163,52 +164,48 @@ async def test_response_with_toolcall( weather_agent = WeatherAgent(proxy_url) - queries = [ - "What's the weather like in Zurich, Switzerland?", - "Tell me the weather for New York", - ] - cities = ["zurich", "new york"] + query = "Tell me the weather for New York" + + city = "new york" # Process each query responses = [] - for index, query in enumerate(queries): - response = weather_agent.get_response(query) - assert response is not None - assert response[0].role == "assistant" - assert response[0].stop_reason == "tool_use" - assert response[0].content[0].type == "text" - assert response[0].content[1].type == "tool_use" - assert cities[index] in response[0].content[1].input["location"].lower() + messages = [{"role": "user", "content": query}] + response = weather_agent.get_response(messages) + assert response is not None + assert response[0].role == "assistant" + assert response[0].stop_reason == "tool_use" + assert response[0].content[0].type == "text" + assert response[0].content[1].type == "tool_use" + assert city in response[0].content[1].input["location"].lower() - assert response[1].role == "assistant" - assert response[1].stop_reason == "end_turn" - assert cities[index] in response[1].content[0].text.lower() - responses.append(response) + assert response[1].role == "assistant" + assert response[1].stop_reason == "end_turn" + assert city in response[1].content[0].text.lower() + responses.append(response) traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" ) traces = await traces_response.json() - assert len(traces) == len(queries) + trace = traces[-1] + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] - for index,trace in enumerate(traces): - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() - trace_messages = trace["messages"] - - assert trace_messages[0]["role"] == "user" - assert trace_messages[0]["content"] == queries[index] - assert trace_messages[1]["role"] == "assistant" - assert cities[index] in trace_messages[1]["content"].lower() - assert trace_messages[2]["role"] == "assistant" - assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" - assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() - assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" - assert cities[index] in trace_messages[4]["content"].lower() + assert trace_messages[0]["role"] == "user" + assert trace_messages[0]["content"] == query + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" + assert city in trace_messages[4]["content"].lower() @@ -219,46 +216,114 @@ async def test_streaming_response_with_toolcall( """Test the chat completion with streaming for the weather agent.""" weather_agent = WeatherAgent(proxy_url) - queries = [ - "What's the weather like in Zurich, Switzerland?", - "Tell me the weather for New York", - ] - cities = ["zurich", "new york"] + query = "Tell me the weather for New York" + city = "new york" - - for index, query in enumerate(queries): - response = weather_agent.get_streaming_response(query) - assert response is not None - assert response[0][0].type == "text" - assert response[0][1].type == "tool_use" - assert response[0][1].name == "get_weather" - assert cities[index] in response[0][1].input["location"].lower() - - assert response[1][0].type == "text" - assert cities[index] in response[1][0].text.lower() + messages = [{"role": "user", "content": query}] + response = weather_agent.get_streaming_response(messages) + assert response is not None + assert response[0][0].type == "text" + assert response[0][1].type == "tool_use" + assert response[0][1].name == "get_weather" + assert city in response[0][1].input["location"].lower() + + assert response[1][0].type == "text" + assert city in response[1][0].text.lower() traces_response = await context.request.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" ) traces = await traces_response.json() - assert len(traces) == len(queries) - for index,trace in enumerate(traces): - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() - trace_messages = trace["messages"] - assert trace_messages[0]["role"] == "user" - assert trace_messages[0]["content"] == queries[index] - assert trace_messages[1]["role"] == "assistant" - assert cities[index] in trace_messages[1]["content"].lower() - assert trace_messages[2]["role"] == "assistant" - assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" - assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() - assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" - assert cities[index] in trace_messages[4]["content"].lower() + trace = traces[-1] + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] + assert trace_messages[0]["role"] == "user" + assert trace_messages[0]["content"] == query + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" + assert city in trace_messages[4]["content"].lower() + +async def test_response_with_toolcall_with_image( + context, explorer_api_url, proxy_url +): + weatherAgent = WeatherAgent(proxy_url) + + image_path1 = Path(__file__).parent.parent / "images" / "new-york.jpeg" + image_path2 = Path(__file__).parent.parent / "images" / "two-cats.png" + + image1 = open(image_path1, "rb") + image2 = open(image_path2, "rb") + base64_image1 = base64.b64encode(image1.read()).decode("utf-8") + base64_image2 = base64.b64encode(image2.read()).decode("utf-8") + query = "get the weather in the city of these images" + city = "new york" + messages = [ + { + "role": "user", "content": [ + + { + "type": "text", + "text": query, + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": base64_image1, + } + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": base64_image2, + } + }, + ] + } + ] + response = weatherAgent.get_response(messages) + assert response is not None + assert response[0].role == "assistant" + assert response[0].stop_reason == "tool_use" + assert response[0].content[0].type == "text" + assert response[0].content[1].type == "tool_use" + assert city in response[0].content[1].input["location"].lower() + + assert response[1].role == "assistant" + assert response[1].stop_reason == "end_turn" + + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weatherAgent.dataset_name}/traces" + ) + traces = await traces_response.json() + + trace = traces[-1] + trace_id = trace["id"] + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] + assert trace_messages[0]["role"] == "user" + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" diff --git a/tests/images/new-york.jpeg b/tests/images/new-york.jpeg new file mode 100644 index 0000000..80bc4a6 Binary files /dev/null and b/tests/images/new-york.jpeg differ