diff --git a/README.md b/README.md index 345db79..1cec817 100644 --- a/README.md +++ b/README.md @@ -36,3 +36,6 @@ To integrate the Proxy with your AI agent, you’ll need to modify how your clie ### **🔹 Anthropic Integration** Coming Soon! + +### Run +./run.sh up \ No newline at end of file diff --git a/proxy/requirements.txt b/proxy/requirements.txt index 6be2e1e..193dbfb 100644 --- a/proxy/requirements.txt +++ b/proxy/requirements.txt @@ -2,4 +2,6 @@ fastapi==0.115.7 httpx==0.28.1 uvicorn==0.34.0 invariant-sdk -starlette-compress==1.4.0 \ No newline at end of file +starlette-compress==1.4.0 +tavily-python +anthropic \ No newline at end of file diff --git a/proxy/routes/anthropic.py b/proxy/routes/anthropic.py index 6e1f61b..86706f9 100644 --- a/proxy/routes/anthropic.py +++ b/proxy/routes/anthropic.py @@ -1,5 +1,186 @@ """Proxy service to forward requests to the Anthropic APIs""" -from fastapi import APIRouter +from fastapi import APIRouter, Header, HTTPException, Depends, Request +import json +import httpx +from typing import Any +from utils.explorer import push_trace +# from .open_ai import push_to_explorer proxy = APIRouter() + +ALLOWED_ANTHROPIC_ENDPOINTS = {"v1/messages"} +IGNORED_HEADERS = [ + "accept-encoding", + "host", + "invariant-authorization", + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-port", + "x-forwarded-proto", + "x-forwarded-server", + "x-real-ip", +] + +MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header" +MISSING_ANTHROPIC_AUTH_HEADER = "Missing athropic authorization header" +NOT_SUPPORTED_ENDPOINT = "Not supported OpenAI endpoint" +FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: " +END_REASONS = [ + "end_turn", + "max_tokens", + "stop_sequence" +] + +def validate_headers( + invariant_authorization: str = Header(None), x_api_key: str = Header(None) +): + """Require the invariant-authorization and authorization headers to be present""" + if invariant_authorization is None: + raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER) + if x_api_key is None: + raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER) + +@proxy.post( + "/{dataset_name}/anthropic/{endpoint:path}", + dependencies=[Depends(validate_headers)], +) +async def anthropic_proxy( + dataset_name: str, + endpoint: str, + request: Request, +): + """Proxy calls to the Anthropic APIs""" + if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS: + raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT) + headers = { + k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS + } + + request_body = await request.body() + + request_body_json = json.loads(request_body) + + anthropic_url = f"https://api.anthropic.com/{endpoint}" + client = httpx.AsyncClient() + + anthropic_request = client.build_request( + "POST", + anthropic_url, + headers=headers, + data=request_body + ) + + invariant_authorization = request.headers.get("invariant-authorization") + + async with client: + response = await client.send(anthropic_request) + await handle_non_streaming_response( + response, dataset_name, request_body_json, invariant_authorization + ) + return response.json() + +async def push_to_explorer( + dataset_name: str, + merged_response: dict[str, Any], + request_body: dict[str, Any], + invariant_authorization: str, +) -> None: + """Pushes the full trace to the Invariant Explorer""" + # Combine the messages from the request body and Anthropic response + messages = request_body.get("messages", []) + messages += [merged_response] + + messages = anthropic_to_invariant_messages(messages) + _ = await push_trace( + dataset_name=dataset_name, + messages=[messages], + invariant_authorization=invariant_authorization, + ) + +async def handle_non_streaming_response( + response: httpx.Response, + dataset_name: str, + request_body_json: dict[str, Any], + invariant_authorization: str, +): + """Handles non-streaming Anthropic responses""" + json_response = response.json() + # Only push the trace to explorer if the last message is an end turn message + if json_response.get("stop_reason") in END_REASONS: + await push_to_explorer( + dataset_name, + json_response, + request_body_json, + invariant_authorization, + ) + +def anthropic_to_invariant_messages( + messages: list[dict], keep_empty_tool_response: bool = False +) -> list[dict]: + """Converts a list of messages from the Anthropic API to the Invariant API format.""" + output = [] + role_mapping = { + "system": lambda msg: {"role": "system", "content": msg["content"]}, + "user": lambda msg: handle_user_message(msg, keep_empty_tool_response), + "assistant": lambda msg: handle_assistant_message(msg), + } + + for message in messages: + handler = role_mapping.get(message["role"]) + if handler: + output.extend(handler(message)) + + return output + +def handle_user_message(message, keep_empty_tool_response): + output = [] + content = message["content"] + if isinstance(content, list): + for sub_message in content: + if sub_message["type"] == "tool_result": + if sub_message["content"]: + output.append( + { + "role": "tool", + "content": sub_message["content"], + "tool_id": sub_message["tool_use_id"], + } + ) + elif keep_empty_tool_response and any(sub_message.values()): + output.append( + { + "role": "tool", + "content": {"is_error": True} if sub_message["is_error"] else {}, + "tool_id": sub_message["tool_use_id"], + } + ) + elif sub_message["type"] == "text": + output.append({"role": "user", "content": sub_message["text"]}) + else: + output.append({"role": "user", "content": content}) + return output + +def handle_assistant_message(message): + output = [] + for sub_message in message["content"]: + if sub_message["type"] == "text": + output.append({"role": "assistant", "content": sub_message.get("text")}) + elif sub_message["type"] == "tool_use": + output.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "tool_id": sub_message.get("id"), + "type": "function", + "function": { + "name": sub_message.get("name"), + "arguments": sub_message.get("input"), + }, + } + ], + } + ) + return output diff --git a/proxy/routes/open_ai.py b/proxy/routes/open_ai.py index 34f709d..74fbb44 100644 --- a/proxy/routes/open_ai.py +++ b/proxy/routes/open_ai.py @@ -300,7 +300,6 @@ async def push_to_explorer( # Combine the messages from the request body and the choices from the OpenAI response messages = request_body.get("messages", []) messages += [choice["message"] for choice in merged_response.get("choices", [])] - _ = await push_trace( dataset_name=dataset_name, messages=[messages], diff --git a/tests/anthropic/test_claude_weather_agent.py b/tests/anthropic/test_claude_weather_agent.py new file mode 100644 index 0000000..bb61c3c --- /dev/null +++ b/tests/anthropic/test_claude_weather_agent.py @@ -0,0 +1,112 @@ +import anthropic +from typing import Dict +import os +from tavily import TavilyClient +import anthropic +from httpx import Client +import os +# from invariant import testing +tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) +import datetime + +class WeatherAgent: + def __init__(self, api_key: str): + dataset_name = "claude_weather_agent_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S")) + invariant_api_key = os.environ.get("INVARIANT_API_KEY") + self.client = anthropic.Anthropic( + http_client=Client( + headers={ + "Invariant-Authorization": f"Bearer {invariant_api_key}" + }, + ), + base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic", + ) + self.get_weather_function = { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\"" + } + }, + "required": ["location"] + } + } + + # self.system_prompt = """You are an assistant that can perform weather searches using function calls. + # When a user asks for weather information, respond with a JSON object specifying + # the function name `get_weather` and the arguments latitude and longitude are needed.""" + + def get_response(self, user_query: str) -> Dict: + """ + Get the response from the agent for a given user query for weather. + """ + messages = [ + { + "role": "user", + "content": user_query + } + ] + while True: + response = self.client.messages.create( + # system=self.system_prompt, + tools = [self.get_weather_function], + model="claude-3-5-sonnet-20241022", + max_tokens=1024, + messages=messages + ) + + # If there's tool call, Extract the tool call parameters from the response + if len(response.content) > 1 and response.content[1].type == "tool_use": + tool_call_params = response.content[1].input + tool_call_result = self.get_weather(tool_call_params["location"]) + tool_call_id = response.content[1].id + messages.append({ + "role": response.role, + "content": response.content + } + ) + messages.append({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": tool_call_result + }] + }) + else: + return response.content[0].text + + def get_weather(self, location: str): + """Get the current weather in a given location using latitude and longitude.""" + query = f"What is the weather in {location}?" + response = tavily_client.search(query) + response_content = response["results"][0]["content"] + return response["results"][0]["title"] + ":\n" + response_content + + +# Initialize agent with your Anthropic API key +anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") +weather_agent = WeatherAgent(anthropic_api_key) + +def test_proxy_response(): + # Example queries + queries = [ + "What's the weather like in Zurich city?", + "Tell me the forecast for New York", + "How's the weather in London next week?" + ] + cities = ["Zurich", "New York", "London"] + # Process each query + for index,query in enumerate(queries): + response = weather_agent.get_response(query) + assert response is not None + assert cities[index] in response \ No newline at end of file