mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-03-07 02:00:52 +00:00
add test for anthropic agent
This commit is contained in:
@@ -23,7 +23,7 @@ IGNORED_HEADERS = [
|
||||
]
|
||||
|
||||
# Error-message strings used by the Anthropic proxy routes in this module.
# The header-related ones are returned as HTTP 400 details by validate_headers.
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
MISSING_AUTH_HEADER = "Missing authorization header"
MISSING_ANTHROPIC_AUTH_HEADER = "Missing Anthropic authorization header"
# Returned as the HTTP 404 detail when the requested Anthropic endpoint is not allowed.
NOT_SUPPORTED_ENDPOINT = "Not supported Anthropic endpoint"
# Prefix only: the trailing ": " suggests the caller appends details (usage not visible here).
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
|
||||
END_REASONS = [
|
||||
@@ -33,13 +33,13 @@ END_REASONS = [
|
||||
]
|
||||
|
||||
def validate_headers(
    invariant_authorization: str = Header(None), x_api_key: str = Header(None)
):
    """Require the Invariant and Anthropic credential headers to be present.

    Both parameters are declared with ``Header(None)``, so each is ``None``
    when the client did not send the corresponding header (FastAPI maps the
    ``x_api_key`` parameter to the ``x-api-key`` header).

    Raises:
        HTTPException: status 400 with ``MISSING_INVARIANT_AUTH_HEADER`` when
            the invariant-authorization header is absent, or with
            ``MISSING_ANTHROPIC_AUTH_HEADER`` when the Anthropic API key
            header is absent.
    """
    if invariant_authorization is None:
        raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER)
    if x_api_key is None:
        raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER)
|
||||
|
||||
@proxy.post(
|
||||
"/{dataset_name}/anthropic/{endpoint:path}",
|
||||
@@ -53,11 +53,9 @@ async def anthropic_proxy(
|
||||
"""Proxy calls to the Anthropic APIs"""
|
||||
if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS:
|
||||
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
|
||||
}
|
||||
# headers["accept-encoding"] = "identity"
|
||||
|
||||
request_body = await request.body()
|
||||
|
||||
@@ -89,14 +87,16 @@ async def push_to_explorer(
|
||||
invariant_authorization: str,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
# Combine the messages from the request body and Anthropic response
|
||||
messages = request_body.get("messages", [])
|
||||
if merged_response is not list:
|
||||
merged_response = [merged_response]
|
||||
messages += merged_response
|
||||
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
if messages[-1].get("stop_reason") in END_REASONS:
|
||||
messages = anthropic_to_invariant_messages(messages)
|
||||
response = await push_trace(
|
||||
_ = await push_trace(
|
||||
dataset_name=dataset_name,
|
||||
messages=[messages],
|
||||
invariant_authorization=invariant_authorization,
|
||||
@@ -122,59 +122,67 @@ def anthropic_to_invariant_messages(
|
||||
) -> list[dict]:
|
||||
"""Converts a list of messages from the Anthropic API to the Invariant API format."""
|
||||
output = []
|
||||
role_mapping = {
|
||||
"system": lambda msg: {"role": "system", "content": msg["content"]},
|
||||
"user": lambda msg: handle_user_message(msg, keep_empty_tool_response),
|
||||
"assistant": lambda msg: handle_assistant_message(msg),
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
output.append({"role": "system", "content": message["content"]})
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], list):
|
||||
for sub_message in message["content"]:
|
||||
if sub_message["type"] == "tool_result":
|
||||
if sub_message["content"]:
|
||||
output.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": sub_message["content"],
|
||||
"tool_id": sub_message["tool_use_id"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
if keep_empty_tool_response and any(
|
||||
[sub_message[k] for k in sub_message]
|
||||
):
|
||||
output.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": {"is_error": True}
|
||||
if sub_message["is_error"]
|
||||
else {},
|
||||
"tool_id": sub_message["tool_use_id"],
|
||||
}
|
||||
)
|
||||
elif sub_message["type"] == "text":
|
||||
output.append({"role": "user", "content": sub_message["text"]})
|
||||
else:
|
||||
output.append({"role": "user", "content": message["content"]})
|
||||
if message["role"] == "assistant":
|
||||
for sub_message in message["content"]:
|
||||
if sub_message["type"] == "text":
|
||||
output.append(
|
||||
{"role": "assistant", "content": sub_message.get("text")}
|
||||
)
|
||||
if sub_message["type"] == "tool_use":
|
||||
handler = role_mapping.get(message["role"])
|
||||
if handler:
|
||||
output.extend(handler(message))
|
||||
|
||||
return output
|
||||
|
||||
def handle_user_message(message, keep_empty_tool_response):
    """Convert one Anthropic user message into Invariant-format messages.

    Args:
        message: Anthropic message dict whose ``"content"`` is either a plain
            value (typically a string) or a list of content-block dicts.
        keep_empty_tool_response: When true, ``tool_result`` blocks with empty
            or missing content are still emitted as a ``tool`` message whose
            content is ``{"is_error": True}`` or ``{}``.

    Returns:
        A list of dicts: ``{"role": "user", ...}`` for text and
        ``{"role": "tool", ..., "tool_id": ...}`` for tool results, in the
        order the blocks appear. Other block types are ignored.
    """
    content = message["content"]

    # Non-list content is passed through unchanged as a single user message.
    if not isinstance(content, list):
        return [{"role": "user", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message.get("type")
        if block_type == "text":
            output.append({"role": "user", "content": sub_message["text"]})
        elif block_type == "tool_result":
            # "content" and "is_error" may be absent from a tool_result block,
            # so read them with .get() instead of raising KeyError.
            tool_content = sub_message.get("content")
            if tool_content:
                output.append(
                    {
                        "role": "tool",
                        "content": tool_content,
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
            elif keep_empty_tool_response and any(sub_message.values()):
                output.append(
                    {
                        "role": "tool",
                        "content": {"is_error": True}
                        if sub_message.get("is_error")
                        else {},
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
    return output
|
||||
|
||||
def handle_assistant_message(message):
    """Convert one Anthropic assistant message into Invariant-format messages.

    Args:
        message: Anthropic message dict whose ``"content"`` is either a plain
            value (typically a string) or a list of content-block dicts.

    Returns:
        A list of assistant dicts in block order: one with text content per
        ``text`` block, and one with ``content`` set to ``None`` and a single
        ``tool_calls`` entry per ``tool_use`` block. Other block types are
        ignored.
    """
    content = message["content"]

    # Non-list content (e.g. a plain string) has no blocks to iterate, so it
    # is passed through as a single message, as the user handler does.
    if not isinstance(content, list):
        return [{"role": "assistant", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "text":
            output.append({"role": "assistant", "content": sub_message.get("text")})
        elif block_type == "tool_use":
            tool_call = {
                "tool_id": sub_message.get("id"),
                "type": "function",
                "function": {
                    "name": sub_message.get("name"),
                    "arguments": sub_message.get("input"),
                },
            }
            output.append(
                {"role": "assistant", "content": None, "tool_calls": [tool_call]}
            )
    return output
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
from anthropic import Anthropic
|
||||
import anthropic
|
||||
from typing import Dict, Optional, List
|
||||
import os
|
||||
from tavily import TavilyClient
|
||||
import anthropic
|
||||
from httpx import Client
|
||||
import os
|
||||
import pytest
|
||||
# from invariant import testing
|
||||
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
|
||||
class WeatherAgent:
|
||||
def __init__(self, api_key: str):
|
||||
# self.client = Anthropic(api_key=api_key)
|
||||
dataset_name = "claude_weather_agent_test7"
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer inv-ff9cb8955c73e3d0afef86a5cef1ee773b1b349d9ed40886c78ef99b8d3dbc5a"
|
||||
},
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {invariant_api_key}"
|
||||
},
|
||||
),
|
||||
base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
self.example_function = {
|
||||
self.get_weather_function = {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"input_schema": {
|
||||
@@ -38,14 +41,13 @@ class WeatherAgent:
|
||||
}
|
||||
}
|
||||
|
||||
# self.system_prompt = """You are an assistant that can perform weather searches using function calls.
|
||||
# When a user asks for weather information, respond with a JSON object specifying
|
||||
# the function name `get_weather` and the arguments latitude and longitude are needed."""
|
||||
|
||||
self.system_prompt = """You are an assistant that can perform weather searches using function calls.
|
||||
When a user asks for weather information, respond with a JSON object specifying
|
||||
the function name `get_weather` and the arguments latitude and longitude are needed."""
|
||||
|
||||
def parse_weather_query(self, user_query: str) -> Dict:
|
||||
def get_response(self, user_query: str) -> Dict:
|
||||
"""
|
||||
Parse user query to extract weather-related parameters using Claude.
|
||||
Get the response from the agent for a given user query for weather.
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
@@ -56,7 +58,7 @@ class WeatherAgent:
|
||||
while True:
|
||||
response = self.client.messages.create(
|
||||
# system=self.system_prompt,
|
||||
tools = [self.example_function],
|
||||
tools = [self.get_weather_function],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=messages
|
||||
@@ -82,7 +84,6 @@ class WeatherAgent:
|
||||
"content": tool_call_result
|
||||
}]
|
||||
})
|
||||
print("messages:",messages,type(messages))
|
||||
else:
|
||||
return response.content[0].text
|
||||
|
||||
@@ -90,16 +91,15 @@ class WeatherAgent:
|
||||
"""Get the current weather in a given location using latitude and longitude."""
|
||||
query = f"What is the weather in {location}?"
|
||||
response = tavily_client.search(query)
|
||||
# breakpoint()
|
||||
response_content = response["results"][0]["content"]
|
||||
return response["results"][0]["title"] + ":\n" + response_content
|
||||
|
||||
# Example usage
|
||||
def main():
|
||||
# Initialize agent with your Anthropic API key
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
weather_agent = WeatherAgent(api_key)
|
||||
|
||||
# Initialize agent with your Anthropic API key
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
weather_agent = WeatherAgent(anthropic_api_key)
|
||||
|
||||
def test_weather_agent():
|
||||
# Example queries
|
||||
queries = [
|
||||
"What's the weather like in Zurich city?",
|
||||
@@ -109,9 +109,6 @@ def main():
|
||||
|
||||
# Process each query
|
||||
for query in queries:
|
||||
print(f"\nQuery: {query}")
|
||||
response = weather_agent.parse_weather_query(query)
|
||||
response = weather_agent.get_response(query)
|
||||
print(f"Response: {response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
assert response is not None
|
||||
Reference in New Issue
Block a user