add test for anthropic agent

This commit is contained in:
Zishan
2025-02-06 14:46:04 +01:00
parent 2f2253220e
commit f2ffed91d3
2 changed files with 91 additions and 86 deletions

View File

@@ -23,7 +23,7 @@ IGNORED_HEADERS = [
]
# Error details returned to API clients by the proxy routes in this module.
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
MISSING_AUTH_HEADER = "Missing authorization header"
# Fixed typo in the client-facing message ("athropic" -> "anthropic").
MISSING_ANTHROPIC_AUTH_HEADER = "Missing anthropic authorization header"
# NOTE(review): this detail says "OpenAI" although it is raised by the
# Anthropic proxy route -- confirm whether the wording is intentional.
NOT_SUPPORTED_ENDPOINT = "Not supported OpenAI endpoint"
# Prefix for the error reported when pushing a trace fails.
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
END_REASONS = [
@@ -33,13 +33,13 @@ END_REASONS = [
]
def validate_headers(
    invariant_authorization: str = Header(None), x_api_key: str = Header(None)
):
    """Require the invariant-authorization and x-api-key headers to be present.

    Both parameters default to ``Header(None)``, so each is ``None`` when the
    corresponding request header was not sent.

    Args:
        invariant_authorization: value of the invariant-authorization header.
        x_api_key: value of the Anthropic API key header.

    Raises:
        HTTPException: status 400 with ``MISSING_INVARIANT_AUTH_HEADER`` or
            ``MISSING_ANTHROPIC_AUTH_HEADER`` as detail when the respective
            header is missing. The invariant header is checked first.
    """
    if invariant_authorization is None:
        raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER)
    if x_api_key is None:
        raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER)
@proxy.post(
"/{dataset_name}/anthropic/{endpoint:path}",
@@ -53,11 +53,9 @@ async def anthropic_proxy(
"""Proxy calls to the Anthropic APIs"""
if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS:
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
# headers["accept-encoding"] = "identity"
request_body = await request.body()
@@ -89,14 +87,16 @@ async def push_to_explorer(
invariant_authorization: str,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
# Combine the messages from the request body and the choices from the OpenAI response
# Combine the messages from the request body and Anthropic response
messages = request_body.get("messages", [])
if merged_response is not list:
merged_response = [merged_response]
messages += merged_response
# Only push the trace to explorer if the last message is an end turn message
if messages[-1].get("stop_reason") in END_REASONS:
messages = anthropic_to_invariant_messages(messages)
response = await push_trace(
_ = await push_trace(
dataset_name=dataset_name,
messages=[messages],
invariant_authorization=invariant_authorization,
@@ -122,59 +122,67 @@ def anthropic_to_invariant_messages(
) -> list[dict]:
"""Converts a list of messages from the Anthropic API to the Invariant API format."""
output = []
role_mapping = {
"system": lambda msg: {"role": "system", "content": msg["content"]},
"user": lambda msg: handle_user_message(msg, keep_empty_tool_response),
"assistant": lambda msg: handle_assistant_message(msg),
}
for message in messages:
if message["role"] == "system":
output.append({"role": "system", "content": message["content"]})
if message["role"] == "user":
if isinstance(message["content"], list):
for sub_message in message["content"]:
if sub_message["type"] == "tool_result":
if sub_message["content"]:
output.append(
{
"role": "tool",
"content": sub_message["content"],
"tool_id": sub_message["tool_use_id"],
}
)
else:
if keep_empty_tool_response and any(
[sub_message[k] for k in sub_message]
):
output.append(
{
"role": "tool",
"content": {"is_error": True}
if sub_message["is_error"]
else {},
"tool_id": sub_message["tool_use_id"],
}
)
elif sub_message["type"] == "text":
output.append({"role": "user", "content": sub_message["text"]})
else:
output.append({"role": "user", "content": message["content"]})
if message["role"] == "assistant":
for sub_message in message["content"]:
if sub_message["type"] == "text":
output.append(
{"role": "assistant", "content": sub_message.get("text")}
)
if sub_message["type"] == "tool_use":
handler = role_mapping.get(message["role"])
if handler:
output.extend(handler(message))
return output
def handle_user_message(message, keep_empty_tool_response):
    """Convert one Anthropic ``user`` message into Invariant-format messages.

    Args:
        message: Anthropic message dict. ``message["content"]`` is either a
            list of content blocks or any other value (e.g. a plain string),
            which is passed through unchanged as a single user message.
        keep_empty_tool_response: when true, a ``tool_result`` block whose
            content is empty or absent is still emitted as a ``tool`` message
            instead of being dropped.

    Returns:
        A list of dicts with role ``"user"`` or ``"tool"``, in block order.
        Blocks whose type is neither ``"tool_result"`` nor ``"text"`` are
        skipped.
    """
    content = message["content"]
    if not isinstance(content, list):
        return [{"role": "user", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "tool_result":
            # "content" and "is_error" are read with .get() so that a
            # tool_result block omitting them does not raise KeyError.
            tool_content = sub_message.get("content")
            if tool_content:
                output.append(
                    {
                        "role": "tool",
                        "content": tool_content,
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
            elif keep_empty_tool_response and any(sub_message.values()):
                # Empty result: record only whether the tool call errored.
                output.append(
                    {
                        "role": "tool",
                        "content": (
                            {"is_error": True}
                            if sub_message.get("is_error")
                            else {}
                        ),
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
        elif block_type == "text":
            output.append({"role": "user", "content": sub_message["text"]})
    return output
def handle_assistant_message(message):
    """Convert one Anthropic ``assistant`` message into Invariant-format messages.

    Args:
        message: Anthropic message dict. ``message["content"]`` is either a
            list of content blocks or a plain string.

    Returns:
        A list of assistant-role dicts in block order: one per ``text`` block
        (its text as content) and one per ``tool_use`` block (``content`` is
        ``None`` and a single-entry ``tool_calls`` list carries the id, name
        and input). Blocks of any other type are skipped. A plain-string
        content yields a single assistant message with that string.
    """
    content = message["content"]
    if isinstance(content, str):
        # Iterating a string would index characters with "type" and raise
        # TypeError; treat it as one text turn instead.
        return [{"role": "assistant", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "text":
            output.append({"role": "assistant", "content": sub_message.get("text")})
        elif block_type == "tool_use":
            tool_call = {
                "tool_id": sub_message.get("id"),
                "type": "function",
                "function": {
                    "name": sub_message.get("name"),
                    "arguments": sub_message.get("input"),
                },
            }
            output.append(
                {"role": "assistant", "content": None, "tool_calls": [tool_call]}
            )
    return output

View File

@@ -1,24 +1,27 @@
from anthropic import Anthropic
import anthropic
from typing import Dict, Optional, List
import os
from tavily import TavilyClient
import anthropic
from httpx import Client
import os
import pytest
# from invariant import testing
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
class WeatherAgent:
def __init__(self, api_key: str):
# self.client = Anthropic(api_key=api_key)
dataset_name = "claude_weather_agent_test7"
self.client = anthropic.Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer inv-ff9cb8955c73e3d0afef86a5cef1ee773b1b349d9ed40886c78ef99b8d3dbc5a"
},
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
self.client = anthropic.Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {invariant_api_key}"
},
),
base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
)
self.example_function = {
self.get_weather_function = {
"name": "get_weather",
"description": "Get the current weather in a given location",
"input_schema": {
@@ -38,14 +41,13 @@ class WeatherAgent:
}
}
# self.system_prompt = """You are an assistant that can perform weather searches using function calls.
# When a user asks for weather information, respond with a JSON object specifying
# the function name `get_weather` and the arguments latitude and longitude are needed."""
self.system_prompt = """You are an assistant that can perform weather searches using function calls.
When a user asks for weather information, respond with a JSON object specifying
the function name `get_weather` and the arguments latitude and longitude are needed."""
def parse_weather_query(self, user_query: str) -> Dict:
def get_response(self, user_query: str) -> Dict:
"""
Parse user query to extract weather-related parameters using Claude.
Get the response from the agent for a given user query for weather.
"""
messages = [
{
@@ -56,7 +58,7 @@ class WeatherAgent:
while True:
response = self.client.messages.create(
# system=self.system_prompt,
tools = [self.example_function],
tools = [self.get_weather_function],
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=messages
@@ -82,7 +84,6 @@ class WeatherAgent:
"content": tool_call_result
}]
})
print("messages:",messages,type(messages))
else:
return response.content[0].text
@@ -90,16 +91,15 @@ class WeatherAgent:
"""Get the current weather in a given location using latitude and longitude."""
query = f"What is the weather in {location}?"
response = tavily_client.search(query)
# breakpoint()
response_content = response["results"][0]["content"]
return response["results"][0]["title"] + ":\n" + response_content
# Example usage
def main():
# Initialize agent with your Anthropic API key
api_key = os.getenv("ANTHROPIC_API_KEY")
weather_agent = WeatherAgent(api_key)
# Initialize agent with your Anthropic API key
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
weather_agent = WeatherAgent(anthropic_api_key)
def test_weather_agent():
# Example queries
queries = [
"What's the weather like in Zurich city?",
@@ -109,9 +109,6 @@ def main():
# Process each query
for query in queries:
print(f"\nQuery: {query}")
response = weather_agent.parse_weather_query(query)
response = weather_agent.get_response(query)
print(f"Response: {response}")
if __name__ == "__main__":
main()
assert response is not None