add test for anthropic agent

This commit is contained in:
Zishan
2025-02-06 14:46:04 +01:00
parent 2f2253220e
commit f2ffed91d3
2 changed files with 91 additions and 86 deletions
@@ -1,24 +1,27 @@
from anthropic import Anthropic
import anthropic
from typing import Dict, Optional, List
import os
from tavily import TavilyClient
import anthropic
from httpx import Client
import os
import pytest
# from invariant import testing
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
class WeatherAgent:
def __init__(self, api_key: str):
# self.client = Anthropic(api_key=api_key)
dataset_name = "claude_weather_agent_test7"
self.client = anthropic.Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer inv-ff9cb8955c73e3d0afef86a5cef1ee773b1b349d9ed40886c78ef99b8d3dbc5a"
},
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
self.client = anthropic.Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {invariant_api_key}"
},
),
base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
)
self.example_function = {
self.get_weather_function = {
"name": "get_weather",
"description": "Get the current weather in a given location",
"input_schema": {
@@ -38,14 +41,13 @@ class WeatherAgent:
}
}
# self.system_prompt = """You are an assistant that can perform weather searches using function calls.
# When a user asks for weather information, respond with a JSON object specifying
# the function name `get_weather` and the arguments latitude and longitude are needed."""
self.system_prompt = """You are an assistant that can perform weather searches using function calls.
When a user asks for weather information, respond with a JSON object specifying
the function name `get_weather` and the arguments latitude and longitude are needed."""
def parse_weather_query(self, user_query: str) -> Dict:
def get_response(self, user_query: str) -> Dict:
"""
Parse user query to extract weather-related parameters using Claude.
Get the response from the agent for a given user query for weather.
"""
messages = [
{
@@ -56,7 +58,7 @@ class WeatherAgent:
while True:
response = self.client.messages.create(
# system=self.system_prompt,
tools = [self.example_function],
tools = [self.get_weather_function],
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=messages
@@ -82,7 +84,6 @@ class WeatherAgent:
"content": tool_call_result
}]
})
print("messages:",messages,type(messages))
else:
return response.content[0].text
@@ -90,16 +91,15 @@ class WeatherAgent:
"""Get the current weather in a given location using latitude and longitude."""
query = f"What is the weather in {location}?"
response = tavily_client.search(query)
# breakpoint()
response_content = response["results"][0]["content"]
return response["results"][0]["title"] + ":\n" + response_content
# Example usage
def main():
# Initialize agent with your Anthropic API key
api_key = os.getenv("ANTHROPIC_API_KEY")
weather_agent = WeatherAgent(api_key)
# Initialize agent with your Anthropic API key
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
weather_agent = WeatherAgent(anthropic_api_key)
def test_weather_agent():
# Example queries
queries = [
"What's the weather like in Zurich city?",
@@ -109,9 +109,6 @@ def main():
# Process each query
for query in queries:
print(f"\nQuery: {query}")
response = weather_agent.parse_weather_query(query)
response = weather_agent.get_response(query)
print(f"Response: {response}")
if __name__ == "__main__":
main()
assert response is not None