From f2ffed91d3434b606318b671e184e33570fec207 Mon Sep 17 00:00:00 2001
From: Zishan
Date: Thu, 6 Feb 2025 14:46:04 +0100
Subject: [PATCH] add test for anthropic agent

---
 proxy/routes/anthropic.py                     | 126 ++++++++++--------
 ...her_agent => test_claude_weather_agent.py} |  51 ++++---
 2 files changed, 91 insertions(+), 86 deletions(-)
 rename proxy/tests/{claude_weather_agent => test_claude_weather_agent.py} (72%)

diff --git a/proxy/routes/anthropic.py b/proxy/routes/anthropic.py
index cbb5e2b..f209663 100644
--- a/proxy/routes/anthropic.py
+++ b/proxy/routes/anthropic.py
@@ -23,7 +23,7 @@ IGNORED_HEADERS = [
 ]
 
 MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
-MISSING_AUTH_HEADER = "Missing authorization header"
+MISSING_ANTHROPIC_AUTH_HEADER = "Missing anthropic authorization header"
 NOT_SUPPORTED_ENDPOINT = "Not supported OpenAI endpoint"
 FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
 END_REASONS = [
@@ -33,13 +33,13 @@ END_REASONS = [
 ]
 
 def validate_headers(
-    invariant_authorization: str = Header(None), authorization: str = Header(None)
+    invariant_authorization: str = Header(None), x_api_key: str = Header(None)
 ):
     """Require the invariant-authorization and authorization headers to be present"""
     if invariant_authorization is None:
         raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER)
-    # if authorization is None:
-    #     raise HTTPException(status_code=400, detail=MISSING_AUTH_HEADER)
+    if x_api_key is None:
+        raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER)
 
 @proxy.post(
     "/{dataset_name}/anthropic/{endpoint:path}",
@@ -53,11 +53,9 @@ async def anthropic_proxy(
     """Proxy calls to the Anthropic APIs"""
     if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS:
         raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
-
     headers = {
         k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
     }
-    # headers["accept-encoding"] = "identity"
 
     request_body = await request.body()
 
@@ -89,14 +87,16 @@ async def push_to_explorer(
     invariant_authorization: str,
 ) -> None:
     """Pushes the full trace to the Invariant Explorer"""
-    # Combine the messages from the request body and the choices from the OpenAI response
+    # Combine the messages from the request body and Anthropic response
     messages = request_body.get("messages", [])
     if merged_response is not list:
         merged_response = [merged_response]
     messages += merged_response
+
+    # Only push the trace to explorer if the last message is an end turn message
     if messages[-1].get("stop_reason") in END_REASONS:
         messages = anthropic_to_invariant_messages(messages)
-        response = await push_trace(
+        _ = await push_trace(
             dataset_name=dataset_name,
             messages=[messages],
             invariant_authorization=invariant_authorization,
@@ -122,59 +122,67 @@ def anthropic_to_invariant_messages(
 ) -> list[dict]:
     """Converts a list of messages from the Anthropic API to the Invariant API format."""
     output = []
+    role_mapping = {
+        "system": lambda msg: {"role": "system", "content": msg["content"]},
+        "user": lambda msg: handle_user_message(msg, keep_empty_tool_response),
+        "assistant": lambda msg: handle_assistant_message(msg),
+    }
+
     for message in messages:
-        if message["role"] == "system":
-            output.append({"role": "system", "content": message["content"]})
-        if message["role"] == "user":
-            if isinstance(message["content"], list):
-                for sub_message in message["content"]:
-                    if sub_message["type"] == "tool_result":
-                        if sub_message["content"]:
-                            output.append(
-                                {
-                                    "role": "tool",
-                                    "content": sub_message["content"],
-                                    "tool_id": sub_message["tool_use_id"],
-                                }
-                            )
-                        else:
-                            if keep_empty_tool_response and any(
-                                [sub_message[k] for k in sub_message]
-                            ):
-                                output.append(
-                                    {
-                                        "role": "tool",
-                                        "content": {"is_error": True}
-                                        if sub_message["is_error"]
-                                        else {},
-                                        "tool_id": sub_message["tool_use_id"],
-                                    }
-                                )
-                    elif sub_message["type"] == "text":
-                        output.append({"role": "user", "content": sub_message["text"]})
-            else:
-                output.append({"role": "user", "content": message["content"]})
-        if message["role"] == "assistant":
-            for sub_message in message["content"]:
-                if sub_message["type"] == "text":
-                    output.append(
-                        {"role": "assistant", "content": sub_message.get("text")}
-                    )
-                if sub_message["type"] == "tool_use":
+        handler = role_mapping.get(message["role"])
+        if handler:
+            output.extend(handler(message))
+
+    return output
+
+def handle_user_message(message, keep_empty_tool_response):
+    output = []
+    content = message["content"]
+    if isinstance(content, list):
+        for sub_message in content:
+            if sub_message["type"] == "tool_result":
+                if sub_message["content"]:
                     output.append(
                         {
-                            "role": "assistant",
-                            "content": None,
-                            "tool_calls": [
-                                {
-                                    "tool_id": sub_message.get("id"),
-                                    "type": "function",
-                                    "function": {
-                                        "name": sub_message.get("name"),
-                                        "arguments": sub_message.get("input"),
-                                    },
-                                }
-                            ],
+                            "role": "tool",
+                            "content": sub_message["content"],
+                            "tool_id": sub_message["tool_use_id"],
                         }
                     )
-    return output
\ No newline at end of file
+                elif keep_empty_tool_response and any(sub_message.values()):
+                    output.append(
+                        {
+                            "role": "tool",
+                            "content": {"is_error": True} if sub_message["is_error"] else {},
+                            "tool_id": sub_message["tool_use_id"],
+                        }
+                    )
+            elif sub_message["type"] == "text":
+                output.append({"role": "user", "content": sub_message["text"]})
+    else:
+        output.append({"role": "user", "content": content})
+    return output
+
+def handle_assistant_message(message):
+    output = []
+    for sub_message in message["content"]:
+        if sub_message["type"] == "text":
+            output.append({"role": "assistant", "content": sub_message.get("text")})
+        elif sub_message["type"] == "tool_use":
+            output.append(
+                {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "tool_id": sub_message.get("id"),
+                            "type": "function",
+                            "function": {
+                                "name": sub_message.get("name"),
+                                "arguments": sub_message.get("input"),
+                            },
+                        }
+                    ],
+                }
+            )
+    return output
diff --git a/proxy/tests/claude_weather_agent b/proxy/tests/test_claude_weather_agent.py
similarity index 72%
rename from proxy/tests/claude_weather_agent
rename to proxy/tests/test_claude_weather_agent.py
index ba26e1a..bacdd54 100644
--- a/proxy/tests/claude_weather_agent
+++ b/proxy/tests/test_claude_weather_agent.py
@@ -1,24 +1,27 @@
-from anthropic import Anthropic
+import anthropic
 from typing import Dict, Optional, List
 import os
 from tavily import TavilyClient
 import anthropic
 from httpx import Client
+import os
+import pytest
+# from invariant import testing
 tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
 
 class WeatherAgent:
     def __init__(self, api_key: str):
-        # self.client = Anthropic(api_key=api_key)
         dataset_name = "claude_weather_agent_test7"
-        self.client = anthropic.Anthropic(
-            http_client=Client(
-                headers={
-                    "Invariant-Authorization": "Bearer inv-ff9cb8955c73e3d0afef86a5cef1ee773b1b349d9ed40886c78ef99b8d3dbc5a"
-                },
+        invariant_api_key = os.environ.get("INVARIANT_API_KEY")
+        self.client = anthropic.Anthropic(
+            http_client=Client(
+                headers={
+                    "Invariant-Authorization": f"Bearer {invariant_api_key}"
+                },
             ),
             base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
         )
-        self.example_function = {
+        self.get_weather_function = {
             "name": "get_weather",
             "description": "Get the current weather in a given location",
             "input_schema": {
@@ -38,14 +41,13 @@ class WeatherAgent:
             }
         }
 
+        # self.system_prompt = """You are an assistant that can perform weather searches using function calls.
+        # When a user asks for weather information, respond with a JSON object specifying
+        # the function name `get_weather` and the arguments latitude and longitude are needed."""
 
-        self.system_prompt = """You are an assistant that can perform weather searches using function calls.
-        When a user asks for weather information, respond with a JSON object specifying
-        the function name `get_weather` and the arguments latitude and longitude are needed."""
-
-    def parse_weather_query(self, user_query: str) -> Dict:
+    def get_response(self, user_query: str) -> Dict:
         """
-        Parse user query to extract weather-related parameters using Claude.
+        Get the response from the agent for a given user query for weather.
         """
         messages = [
             {
@@ -56,7 +58,7 @@ class WeatherAgent:
         while True:
             response = self.client.messages.create(
                 # system=self.system_prompt,
-                tools = [self.example_function],
+                tools = [self.get_weather_function],
                 model="claude-3-5-sonnet-20241022",
                 max_tokens=1024,
                 messages=messages
@@ -82,7 +84,6 @@ class WeatherAgent:
                         "content": tool_call_result
                     }]
                 })
-                print("messages:",messages,type(messages))
             else:
                 return response.content[0].text
 
@@ -90,16 +91,15 @@ class WeatherAgent:
         """Get the current weather in a given location using latitude and longitude."""
         query = f"What is the weather in {location}?"
         response = tavily_client.search(query)
-        # breakpoint()
         response_content = response["results"][0]["content"]
         return response["results"][0]["title"] + ":\n" + response_content
 
-# Example usage
-def main():
-    # Initialize agent with your Anthropic API key
-    api_key = os.getenv("ANTHROPIC_API_KEY")
-    weather_agent = WeatherAgent(api_key)
+
+# Initialize agent with your Anthropic API key
+anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
+weather_agent = WeatherAgent(anthropic_api_key)
 
+def test_weather_agent():
     # Example queries
     queries = [
         "What's the weather like in Zurich city?",
@@ -109,9 +109,6 @@ def main():
 
     # Process each query
     for query in queries:
-        print(f"\nQuery: {query}")
-        response = weather_agent.parse_weather_query(query)
+        response = weather_agent.get_response(query)
         print(f"Response: {response}")
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
+        assert response is not None
\ No newline at end of file