mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-05 12:48:11 +02:00
add anthropic image handle
add anthropic image handle
This commit is contained in:
+31
-18
@@ -136,13 +136,12 @@ async def handle_non_streaming_response(
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
if json_response.get("stop_reason") in END_REASONS:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
json_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
json_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
async def handle_streaming_response(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -154,7 +153,6 @@ async def handle_streaming_response(
|
||||
formatted_invariant_response = []
|
||||
|
||||
response = await client.send(anthropic_request, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
@@ -169,7 +167,6 @@ async def handle_streaming_response(
|
||||
chunk_decode = chunk.decode().strip()
|
||||
if not chunk_decode:
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
process_chunk_text(
|
||||
@@ -177,13 +174,12 @@ async def handle_streaming_response(
|
||||
formatted_invariant_response
|
||||
)
|
||||
|
||||
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
@@ -249,11 +245,11 @@ def anthropic_to_invariant_messages(
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def handle_user_message(message, keep_empty_tool_response):
|
||||
output = []
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
user_content = []
|
||||
for sub_message in content:
|
||||
if sub_message["type"] == "tool_result":
|
||||
if sub_message["content"]:
|
||||
@@ -275,7 +271,24 @@ def handle_user_message(message, keep_empty_tool_response):
|
||||
}
|
||||
)
|
||||
elif sub_message["type"] == "text":
|
||||
output.append({"role": "user", "content": sub_message["text"]})
|
||||
user_content.append({
|
||||
"type":"text",
|
||||
"text":sub_message["text"]
|
||||
})
|
||||
elif sub_message["type"] == "image":
|
||||
user_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:"+sub_message["source"]["media_type"]+";base64,"+sub_message["source"]["data"],
|
||||
},
|
||||
},
|
||||
|
||||
)
|
||||
if user_content:
|
||||
output.append({
|
||||
"role": "user",
|
||||
"content": user_content
|
||||
})
|
||||
else:
|
||||
output.append({"role": "user", "content": content})
|
||||
return output
|
||||
|
||||
@@ -5,7 +5,10 @@ import json
|
||||
import anthropic
|
||||
import pytest
|
||||
from httpx import Client
|
||||
import base64
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from util import * # needed for pytest fixtures
|
||||
@@ -45,11 +48,10 @@ class WeatherAgent:
|
||||
},
|
||||
}
|
||||
|
||||
def get_response(self, user_query: str) -> Dict:
|
||||
def get_response(self, messages: str) -> Dict:
|
||||
"""
|
||||
Get the response from the agent for a given user query for weather.
|
||||
"""
|
||||
messages = [{"role": "user", "content": user_query}]
|
||||
response_list = []
|
||||
while True:
|
||||
response = self.client.messages.create(
|
||||
@@ -80,8 +82,7 @@ class WeatherAgent:
|
||||
else:
|
||||
return response_list
|
||||
|
||||
def get_streaming_response(self, user_query: str) -> Dict:
|
||||
messages = [{"role": "user", "content": user_query}]
|
||||
def get_streaming_response(self, messages: str) -> Dict:
|
||||
response_list = []
|
||||
|
||||
def clean_quotes(text):
|
||||
@@ -163,52 +164,48 @@ async def test_response_with_toolcall(
|
||||
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
|
||||
queries = [
|
||||
"What's the weather like in Zurich, Switzerland?",
|
||||
"Tell me the weather for New York",
|
||||
]
|
||||
cities = ["zurich", "new york"]
|
||||
query = "Tell me the weather for New York"
|
||||
|
||||
city = "new york"
|
||||
# Process each query
|
||||
responses = []
|
||||
for index, query in enumerate(queries):
|
||||
response = weather_agent.get_response(query)
|
||||
assert response is not None
|
||||
assert response[0].role == "assistant"
|
||||
assert response[0].stop_reason == "tool_use"
|
||||
assert response[0].content[0].type == "text"
|
||||
assert response[0].content[1].type == "tool_use"
|
||||
assert cities[index] in response[0].content[1].input["location"].lower()
|
||||
messages = [{"role": "user", "content": query}]
|
||||
response = weather_agent.get_response(messages)
|
||||
assert response is not None
|
||||
assert response[0].role == "assistant"
|
||||
assert response[0].stop_reason == "tool_use"
|
||||
assert response[0].content[0].type == "text"
|
||||
assert response[0].content[1].type == "tool_use"
|
||||
assert city in response[0].content[1].input["location"].lower()
|
||||
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
assert cities[index] in response[1].content[0].text.lower()
|
||||
responses.append(response)
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
assert city in response[1].content[0].text.lower()
|
||||
responses.append(response)
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
|
||||
for index,trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == queries[index]
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert cities[index] in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert cities[index] in trace_messages[4]["content"].lower()
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
|
||||
|
||||
|
||||
@@ -219,46 +216,114 @@ async def test_streaming_response_with_toolcall(
|
||||
"""Test the chat completion with streaming for the weather agent."""
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
|
||||
queries = [
|
||||
"What's the weather like in Zurich, Switzerland?",
|
||||
"Tell me the weather for New York",
|
||||
]
|
||||
cities = ["zurich", "new york"]
|
||||
query = "Tell me the weather for New York"
|
||||
city = "new york"
|
||||
|
||||
|
||||
for index, query in enumerate(queries):
|
||||
response = weather_agent.get_streaming_response(query)
|
||||
assert response is not None
|
||||
assert response[0][0].type == "text"
|
||||
assert response[0][1].type == "tool_use"
|
||||
assert response[0][1].name == "get_weather"
|
||||
assert cities[index] in response[0][1].input["location"].lower()
|
||||
|
||||
assert response[1][0].type == "text"
|
||||
assert cities[index] in response[1][0].text.lower()
|
||||
messages = [{"role": "user", "content": query}]
|
||||
response = weather_agent.get_streaming_response(messages)
|
||||
assert response is not None
|
||||
assert response[0][0].type == "text"
|
||||
assert response[0][1].type == "tool_use"
|
||||
assert response[0][1].name == "get_weather"
|
||||
assert city in response[0][1].input["location"].lower()
|
||||
|
||||
assert response[1][0].type == "text"
|
||||
assert city in response[1][0].text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index,trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == queries[index]
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert cities[index] in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert cities[index] in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert cities[index] in trace_messages[4]["content"].lower()
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
|
||||
|
||||
async def test_response_with_toolcall_with_image(
|
||||
context, explorer_api_url, proxy_url
|
||||
):
|
||||
weatherAgent = WeatherAgent(proxy_url)
|
||||
|
||||
image_path1 = Path(__file__).parent.parent / "images" / "new-york.jpeg"
|
||||
image_path2 = Path(__file__).parent.parent / "images" / "two-cats.png"
|
||||
|
||||
image1 = open(image_path1, "rb")
|
||||
image2 = open(image_path2, "rb")
|
||||
base64_image1 = base64.b64encode(image1.read()).decode("utf-8")
|
||||
base64_image2 = base64.b64encode(image2.read()).decode("utf-8")
|
||||
query = "get the weather in the city of these images"
|
||||
city = "new york"
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
|
||||
{
|
||||
"type": "text",
|
||||
"text": query,
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": base64_image1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_image2,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
response = weatherAgent.get_response(messages)
|
||||
assert response is not None
|
||||
assert response[0].role == "assistant"
|
||||
assert response[0].stop_reason == "tool_use"
|
||||
assert response[0].content[0].type == "text"
|
||||
assert response[0].content[1].type == "tool_use"
|
||||
assert city in response[0].content[1].input["location"].lower()
|
||||
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weatherAgent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
Reference in New Issue
Block a user