Add Anthropic image handling

Add Anthropic image handling
This commit is contained in:
zishan-wei
2025-02-24 10:17:24 +01:00
committed by GitHub
3 changed files with 172 additions and 94 deletions
+31 -18
View File
@@ -136,13 +136,12 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
if json_response.get("stop_reason") in END_REASONS:
await push_to_explorer(
dataset_name,
json_response,
request_body_json,
invariant_authorization,
)
await push_to_explorer(
dataset_name,
json_response,
request_body_json,
invariant_authorization,
)
async def handle_streaming_response(
client: httpx.AsyncClient,
@@ -154,7 +153,6 @@ async def handle_streaming_response(
formatted_invariant_response = []
response = await client.send(anthropic_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
@@ -169,7 +167,6 @@ async def handle_streaming_response(
chunk_decode = chunk.decode().strip()
if not chunk_decode:
continue
yield chunk
process_chunk_text(
@@ -177,13 +174,12 @@ async def handle_streaming_response(
formatted_invariant_response
)
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
generator = event_generator()
@@ -249,11 +245,11 @@ def anthropic_to_invariant_messages(
return output
def handle_user_message(message, keep_empty_tool_response):
output = []
content = message["content"]
if isinstance(content, list):
user_content = []
for sub_message in content:
if sub_message["type"] == "tool_result":
if sub_message["content"]:
@@ -275,7 +271,24 @@ def handle_user_message(message, keep_empty_tool_response):
}
)
elif sub_message["type"] == "text":
output.append({"role": "user", "content": sub_message["text"]})
user_content.append({
"type":"text",
"text":sub_message["text"]
})
elif sub_message["type"] == "image":
user_content.append({
"type": "image_url",
"image_url": {
"url": "data:"+sub_message["source"]["media_type"]+";base64,"+sub_message["source"]["data"],
},
},
)
if user_content:
output.append({
"role": "user",
"content": user_content
})
else:
output.append({"role": "user", "content": content})
return output
+141 -76
View File
@@ -5,7 +5,10 @@ import json
import anthropic
import pytest
from httpx import Client
import base64
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
@@ -45,11 +48,10 @@ class WeatherAgent:
},
}
def get_response(self, user_query: str) -> Dict:
def get_response(self, messages: str) -> Dict:
"""
Get the response from the agent for a given user query for weather.
"""
messages = [{"role": "user", "content": user_query}]
response_list = []
while True:
response = self.client.messages.create(
@@ -80,8 +82,7 @@ class WeatherAgent:
else:
return response_list
def get_streaming_response(self, user_query: str) -> Dict:
messages = [{"role": "user", "content": user_query}]
def get_streaming_response(self, messages: str) -> Dict:
response_list = []
def clean_quotes(text):
@@ -163,52 +164,48 @@ async def test_response_with_toolcall(
weather_agent = WeatherAgent(proxy_url)
queries = [
"What's the weather like in Zurich, Switzerland?",
"Tell me the weather for New York",
]
cities = ["zurich", "new york"]
query = "Tell me the weather for New York"
city = "new york"
# Process each query
responses = []
for index, query in enumerate(queries):
response = weather_agent.get_response(query)
assert response is not None
assert response[0].role == "assistant"
assert response[0].stop_reason == "tool_use"
assert response[0].content[0].type == "text"
assert response[0].content[1].type == "tool_use"
assert cities[index] in response[0].content[1].input["location"].lower()
messages = [{"role": "user", "content": query}]
response = weather_agent.get_response(messages)
assert response is not None
assert response[0].role == "assistant"
assert response[0].stop_reason == "tool_use"
assert response[0].content[0].type == "text"
assert response[0].content[1].type == "tool_use"
assert city in response[0].content[1].input["location"].lower()
assert response[1].role == "assistant"
assert response[1].stop_reason == "end_turn"
assert cities[index] in response[1].content[0].text.lower()
responses.append(response)
assert response[1].role == "assistant"
assert response[1].stop_reason == "end_turn"
assert city in response[1].content[0].text.lower()
responses.append(response)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == len(queries)
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
for index,trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == queries[index]
assert trace_messages[1]["role"] == "assistant"
assert cities[index] in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert cities[index] in trace_messages[4]["content"].lower()
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
@@ -219,46 +216,114 @@ async def test_streaming_response_with_toolcall(
"""Test the chat completion with streaming for the weather agent."""
weather_agent = WeatherAgent(proxy_url)
queries = [
"What's the weather like in Zurich, Switzerland?",
"Tell me the weather for New York",
]
cities = ["zurich", "new york"]
query = "Tell me the weather for New York"
city = "new york"
for index, query in enumerate(queries):
response = weather_agent.get_streaming_response(query)
assert response is not None
assert response[0][0].type == "text"
assert response[0][1].type == "tool_use"
assert response[0][1].name == "get_weather"
assert cities[index] in response[0][1].input["location"].lower()
assert response[1][0].type == "text"
assert cities[index] in response[1][0].text.lower()
messages = [{"role": "user", "content": query}]
response = weather_agent.get_streaming_response(messages)
assert response is not None
assert response[0][0].type == "text"
assert response[0][1].type == "tool_use"
assert response[0][1].name == "get_weather"
assert city in response[0][1].input["location"].lower()
assert response[1][0].type == "text"
assert city in response[1][0].text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == len(queries)
for index,trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == queries[index]
assert trace_messages[1]["role"] == "assistant"
assert cities[index] in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert cities[index] in trace_messages[4]["content"].lower()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
async def test_response_with_toolcall_with_image(
    context, explorer_api_url, proxy_url
):
    """Test a tool-calling request whose user message carries base64 images.

    Builds one user message containing a text block and two image blocks
    (a JPEG and a PNG), sends it via ``WeatherAgent.get_response``, and checks
    that the first response is a ``tool_use`` turn for New York followed by an
    ``end_turn`` response. It then fetches the last trace listed for the
    agent's dataset and checks the message roles and the ``get_weather`` call.
    """
    weather_agent = WeatherAgent(proxy_url)

    images_dir = Path(__file__).parent.parent / "images"

    def image_block(file_name, media_type):
        # read_bytes() opens and closes the file itself, so no handle leaks
        # (the previous bare open() calls were never closed).
        encoded = base64.b64encode(
            (images_dir / file_name).read_bytes()
        ).decode("utf-8")
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": encoded,
            },
        }

    query = "get the weather in the city of these images"
    city = "new york"

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": query,
                },
                image_block("new-york.jpeg", "image/jpeg"),
                image_block("two-cats.png", "image/png"),
            ],
        }
    ]

    response = weather_agent.get_response(messages)
    assert response is not None
    # First turn: the model announces and issues the tool call.
    assert response[0].role == "assistant"
    assert response[0].stop_reason == "tool_use"
    assert response[0].content[0].type == "text"
    assert response[0].content[1].type == "tool_use"
    assert city in response[0].content[1].input["location"].lower()

    # Second turn: the model finishes after receiving the tool result.
    assert response[1].role == "assistant"
    assert response[1].stop_reason == "end_turn"

    traces_response = await context.request.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
    )
    traces = await traces_response.json()

    # Only the last listed trace is inspected.
    trace_id = traces[-1]["id"]
    trace_response = await context.request.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}"
    )
    trace = await trace_response.json()
    trace_messages = trace["messages"]

    assert trace_messages[0]["role"] == "user"
    assert trace_messages[1]["role"] == "assistant"
    assert city in trace_messages[1]["content"].lower()
    assert trace_messages[2]["role"] == "assistant"
    assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
    assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
    assert trace_messages[3]["role"] == "tool"
    assert trace_messages[4]["role"] == "assistant"
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB