Implement Anthropic proxy

Anthropic implement
This commit is contained in:
zishan-wei
2025-02-07 10:39:03 +01:00
committed by GitHub
5 changed files with 300 additions and 3 deletions
+3
View File
@@ -36,3 +36,6 @@ To integrate the Proxy with your AI agent, youll need to modify how your clie
### **🔹 Anthropic Integration**
Coming Soon!
### Run
./run.sh up
+3 -1
View File
@@ -2,4 +2,6 @@ fastapi==0.115.7
httpx==0.28.1
uvicorn==0.34.0
invariant-sdk
starlette-compress==1.4.0
starlette-compress==1.4.0
tavily-python
anthropic
+182 -1
View File
@@ -1,5 +1,186 @@
"""Proxy service to forward requests to the Anthropic APIs"""
from fastapi import APIRouter
from fastapi import APIRouter, Header, HTTPException, Depends, Request
import json
import httpx
from typing import Any
from utils.explorer import push_trace
# from .open_ai import push_to_explorer
proxy = APIRouter()
ALLOWED_ANTHROPIC_ENDPOINTS = {"v1/messages"}
IGNORED_HEADERS = [
"accept-encoding",
"host",
"invariant-authorization",
"x-forwarded-for",
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"x-forwarded-server",
"x-real-ip",
]
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
MISSING_ANTHROPIC_AUTH_HEADER = "Missing athropic authorization header"
NOT_SUPPORTED_ENDPOINT = "Not supported OpenAI endpoint"
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
END_REASONS = [
"end_turn",
"max_tokens",
"stop_sequence"
]
def validate_headers(
invariant_authorization: str = Header(None), x_api_key: str = Header(None)
):
"""Require the invariant-authorization and authorization headers to be present"""
if invariant_authorization is None:
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_HEADER)
if x_api_key is None:
raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER)
@proxy.post(
"/{dataset_name}/anthropic/{endpoint:path}",
dependencies=[Depends(validate_headers)],
)
async def anthropic_proxy(
dataset_name: str,
endpoint: str,
request: Request,
):
"""Proxy calls to the Anthropic APIs"""
if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS:
raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
request_body = await request.body()
request_body_json = json.loads(request_body)
anthropic_url = f"https://api.anthropic.com/{endpoint}"
client = httpx.AsyncClient()
anthropic_request = client.build_request(
"POST",
anthropic_url,
headers=headers,
data=request_body
)
invariant_authorization = request.headers.get("invariant-authorization")
async with client:
response = await client.send(anthropic_request)
await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
)
return response.json()
async def push_to_explorer(
dataset_name: str,
merged_response: dict[str, Any],
request_body: dict[str, Any],
invariant_authorization: str,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
# Combine the messages from the request body and Anthropic response
messages = request_body.get("messages", [])
messages += [merged_response]
messages = anthropic_to_invariant_messages(messages)
_ = await push_trace(
dataset_name=dataset_name,
messages=[messages],
invariant_authorization=invariant_authorization,
)
async def handle_non_streaming_response(
response: httpx.Response,
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
):
"""Handles non-streaming Anthropic responses"""
json_response = response.json()
# Only push the trace to explorer if the last message is an end turn message
if json_response.get("stop_reason") in END_REASONS:
await push_to_explorer(
dataset_name,
json_response,
request_body_json,
invariant_authorization,
)
def anthropic_to_invariant_messages(
messages: list[dict], keep_empty_tool_response: bool = False
) -> list[dict]:
"""Converts a list of messages from the Anthropic API to the Invariant API format."""
output = []
role_mapping = {
"system": lambda msg: {"role": "system", "content": msg["content"]},
"user": lambda msg: handle_user_message(msg, keep_empty_tool_response),
"assistant": lambda msg: handle_assistant_message(msg),
}
for message in messages:
handler = role_mapping.get(message["role"])
if handler:
output.extend(handler(message))
return output
def handle_user_message(message, keep_empty_tool_response):
output = []
content = message["content"]
if isinstance(content, list):
for sub_message in content:
if sub_message["type"] == "tool_result":
if sub_message["content"]:
output.append(
{
"role": "tool",
"content": sub_message["content"],
"tool_id": sub_message["tool_use_id"],
}
)
elif keep_empty_tool_response and any(sub_message.values()):
output.append(
{
"role": "tool",
"content": {"is_error": True} if sub_message["is_error"] else {},
"tool_id": sub_message["tool_use_id"],
}
)
elif sub_message["type"] == "text":
output.append({"role": "user", "content": sub_message["text"]})
else:
output.append({"role": "user", "content": content})
return output
def handle_assistant_message(message):
output = []
for sub_message in message["content"]:
if sub_message["type"] == "text":
output.append({"role": "assistant", "content": sub_message.get("text")})
elif sub_message["type"] == "tool_use":
output.append(
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"tool_id": sub_message.get("id"),
"type": "function",
"function": {
"name": sub_message.get("name"),
"arguments": sub_message.get("input"),
},
}
],
}
)
return output
-1
View File
@@ -300,7 +300,6 @@ async def push_to_explorer(
# Combine the messages from the request body and the choices from the OpenAI response
messages = request_body.get("messages", [])
messages += [choice["message"] for choice in merged_response.get("choices", [])]
_ = await push_trace(
dataset_name=dataset_name,
messages=[messages],
@@ -0,0 +1,112 @@
import anthropic
from typing import Dict
import os
from tavily import TavilyClient
import anthropic
from httpx import Client
import os
# from invariant import testing
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
import datetime
class WeatherAgent:
def __init__(self, api_key: str):
dataset_name = "claude_weather_agent_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
invariant_api_key = os.environ.get("INVARIANT_API_KEY")
self.client = anthropic.Anthropic(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {invariant_api_key}"
},
),
base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
)
self.get_weather_function = {
"name": "get_weather",
"description": "Get the current weather in a given location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature, either \"celsius\" or \"fahrenheit\""
}
},
"required": ["location"]
}
}
# self.system_prompt = """You are an assistant that can perform weather searches using function calls.
# When a user asks for weather information, respond with a JSON object specifying
# the function name `get_weather` and the arguments latitude and longitude are needed."""
def get_response(self, user_query: str) -> Dict:
"""
Get the response from the agent for a given user query for weather.
"""
messages = [
{
"role": "user",
"content": user_query
}
]
while True:
response = self.client.messages.create(
# system=self.system_prompt,
tools = [self.get_weather_function],
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=messages
)
# If there's tool call, Extract the tool call parameters from the response
if len(response.content) > 1 and response.content[1].type == "tool_use":
tool_call_params = response.content[1].input
tool_call_result = self.get_weather(tool_call_params["location"])
tool_call_id = response.content[1].id
messages.append({
"role": response.role,
"content": response.content
}
)
messages.append({
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": tool_call_result
}]
})
else:
return response.content[0].text
def get_weather(self, location: str):
"""Get the current weather in a given location using latitude and longitude."""
query = f"What is the weather in {location}?"
response = tavily_client.search(query)
response_content = response["results"][0]["content"]
return response["results"][0]["title"] + ":\n" + response_content
# Initialize agent with your Anthropic API key
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
weather_agent = WeatherAgent(anthropic_api_key)
def test_proxy_response():
# Example queries
queries = [
"What's the weather like in Zurich city?",
"Tell me the forecast for New York",
"How's the weather in London next week?"
]
cities = ["Zurich", "New York", "London"]
# Process each query
for index,query in enumerate(queries):
response = weather_agent.get_response(query)
assert response is not None
assert cities[index] in response