mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 23:09:44 +02:00
add test for anthropic agent
This commit is contained in:
@@ -1,24 +1,27 @@
|
||||
from anthropic import Anthropic
|
||||
import anthropic
|
||||
from typing import Dict, Optional, List
|
||||
import os
|
||||
from tavily import TavilyClient
|
||||
import anthropic
|
||||
from httpx import Client
|
||||
import os
|
||||
import pytest
|
||||
# from invariant import testing
|
||||
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
|
||||
class WeatherAgent:
|
||||
def __init__(self, api_key: str):
|
||||
# self.client = Anthropic(api_key=api_key)
|
||||
dataset_name = "claude_weather_agent_test7"
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer inv-ff9cb8955c73e3d0afef86a5cef1ee773b1b349d9ed40886c78ef99b8d3dbc5a"
|
||||
},
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {invariant_api_key}"
|
||||
},
|
||||
),
|
||||
base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
self.example_function = {
|
||||
self.get_weather_function = {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"input_schema": {
|
||||
@@ -38,14 +41,13 @@ class WeatherAgent:
|
||||
}
|
||||
}
|
||||
|
||||
# self.system_prompt = """You are an assistant that can perform weather searches using function calls.
|
||||
# When a user asks for weather information, respond with a JSON object specifying
|
||||
# the function name `get_weather` and the arguments latitude and longitude are needed."""
|
||||
|
||||
self.system_prompt = """You are an assistant that can perform weather searches using function calls.
|
||||
When a user asks for weather information, respond with a JSON object specifying
|
||||
the function name `get_weather` and the arguments latitude and longitude are needed."""
|
||||
|
||||
def parse_weather_query(self, user_query: str) -> Dict:
|
||||
def get_response(self, user_query: str) -> Dict:
|
||||
"""
|
||||
Parse user query to extract weather-related parameters using Claude.
|
||||
Get the response from the agent for a given user query for weather.
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
@@ -56,7 +58,7 @@ class WeatherAgent:
|
||||
while True:
|
||||
response = self.client.messages.create(
|
||||
# system=self.system_prompt,
|
||||
tools = [self.example_function],
|
||||
tools = [self.get_weather_function],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=messages
|
||||
@@ -82,7 +84,6 @@ class WeatherAgent:
|
||||
"content": tool_call_result
|
||||
}]
|
||||
})
|
||||
print("messages:",messages,type(messages))
|
||||
else:
|
||||
return response.content[0].text
|
||||
|
||||
@@ -90,16 +91,15 @@ class WeatherAgent:
|
||||
"""Get the current weather in a given location using latitude and longitude."""
|
||||
query = f"What is the weather in {location}?"
|
||||
response = tavily_client.search(query)
|
||||
# breakpoint()
|
||||
response_content = response["results"][0]["content"]
|
||||
return response["results"][0]["title"] + ":\n" + response_content
|
||||
|
||||
# Example usage
|
||||
def main():
|
||||
# Initialize agent with your Anthropic API key
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
weather_agent = WeatherAgent(api_key)
|
||||
|
||||
# Initialize agent with your Anthropic API key
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
weather_agent = WeatherAgent(anthropic_api_key)
|
||||
|
||||
def test_weather_agent():
|
||||
# Example queries
|
||||
queries = [
|
||||
"What's the weather like in Zurich city?",
|
||||
@@ -109,9 +109,6 @@ def main():
|
||||
|
||||
# Process each query
|
||||
for query in queries:
|
||||
print(f"\nQuery: {query}")
|
||||
response = weather_agent.parse_weather_query(query)
|
||||
response = weather_agent.get_response(query)
|
||||
print(f"Response: {response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
assert response is not None
|
||||
Reference in New Issue
Block a user