mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-24 07:44:01 +02:00
Implement Anthropic proxy
Anthropic implement
This commit is contained in:
@@ -36,3 +36,6 @@ To integrate the Proxy with your AI agent, you’ll need to modify how your clie
|
||||
|
||||
### **🔹 Anthropic Integration**
|
||||
Coming Soon!
|
||||
|
||||
### Run
|
||||
./run.sh up
|
||||
@@ -2,4 +2,6 @@ fastapi==0.115.7
|
||||
httpx==0.28.1
|
||||
uvicorn==0.34.0
|
||||
invariant-sdk
|
||||
starlette-compress==1.4.0
|
||||
starlette-compress==1.4.0
|
||||
tavily-python
|
||||
anthropic
|
||||
+182
-1
@@ -1,5 +1,186 @@
|
||||
"""Proxy service to forward requests to the Anthropic APIs"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Header, HTTPException, Depends, Request
|
||||
import json
|
||||
import httpx
|
||||
from typing import Any
|
||||
from utils.explorer import push_trace
|
||||
# from .open_ai import push_to_explorer
|
||||
|
||||
# Router on which the Anthropic proxy route in this module is registered.
proxy = APIRouter()
|
||||
|
||||
# Anthropic API paths (relative to the API host) that the proxy forwards.
ALLOWED_ANTHROPIC_ENDPOINTS = {"v1/messages"}

# Incoming request headers that are dropped before the request is forwarded
# upstream; compared against the lower-cased header name.
IGNORED_HEADERS = [
    "accept-encoding",
    "host",
    "invariant-authorization",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-proto",
    "x-forwarded-server",
    "x-real-ip",
]

# Error details returned to the caller / used when reporting failures.
MISSING_INVARIANT_AUTH_HEADER = "Missing invariant-authorization header"
MISSING_ANTHROPIC_AUTH_HEADER = "Missing anthropic authorization header"
NOT_SUPPORTED_ENDPOINT = "Not supported Anthropic endpoint"
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "

# Values of the response "stop_reason" field for which the trace is pushed
# to the explorer.
END_REASONS = [
    "end_turn",
    "max_tokens",
    "stop_sequence",
]
|
||||
|
||||
def validate_headers(
    invariant_authorization: str = Header(None), x_api_key: str = Header(None)
):
    """Reject requests that lack the invariant-authorization or x-api-key header.

    Raises:
        HTTPException: 400 with the matching "missing header" detail when
            either header value is None.
    """
    required = (
        (invariant_authorization, MISSING_INVARIANT_AUTH_HEADER),
        (x_api_key, MISSING_ANTHROPIC_AUTH_HEADER),
    )
    # Checked in declaration order so the Invariant header is reported first.
    for header_value, missing_detail in required:
        if header_value is None:
            raise HTTPException(status_code=400, detail=missing_detail)
|
||||
|
||||
@proxy.post(
    "/{dataset_name}/anthropic/{endpoint:path}",
    dependencies=[Depends(validate_headers)],
)
async def anthropic_proxy(
    dataset_name: str,
    endpoint: str,
    request: Request,
):
    """Proxy calls to the Anthropic APIs.

    Forwards the incoming body and headers (minus IGNORED_HEADERS) to
    ``https://api.anthropic.com/{endpoint}``, passes the upstream response to
    ``handle_non_streaming_response`` and relays the upstream JSON body
    together with the upstream status code.

    Args:
        dataset_name: Dataset name taken from the request path.
        endpoint: Anthropic API path taken from the request path.
        request: The incoming request.

    Raises:
        HTTPException: 404 if ``endpoint`` is not in
            ALLOWED_ANTHROPIC_ENDPOINTS; 400 if the body is not valid JSON.
    """
    # Local import: keeps the module-level import block untouched.
    from fastapi.responses import JSONResponse

    if endpoint not in ALLOWED_ANTHROPIC_ENDPOINTS:
        raise HTTPException(status_code=404, detail=NOT_SUPPORTED_ENDPOINT)

    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
    }

    request_body = await request.body()
    try:
        request_body_json = json.loads(request_body)
    except ValueError as err:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses;
        # a malformed body is a client error, not an internal server error.
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from err

    invariant_authorization = request.headers.get("invariant-authorization")
    anthropic_url = f"https://api.anthropic.com/{endpoint}"

    # NOTE(review): the client is created with httpx's default timeout; long
    # generations may need a larger value - confirm.
    async with httpx.AsyncClient() as client:
        # Raw bytes are sent with content=; httpx deprecates bytes via data=.
        anthropic_request = client.build_request(
            "POST",
            anthropic_url,
            headers=headers,
            content=request_body,
        )
        response = await client.send(anthropic_request)
        await handle_non_streaming_response(
            response, dataset_name, request_body_json, invariant_authorization
        )
        # Relay the upstream status code so upstream errors are not
        # reported to the caller as 200.
        return JSONResponse(
            content=response.json(), status_code=response.status_code
        )
|
||||
|
||||
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Pushes the full trace to the Invariant Explorer.

    The trace is the request's "messages" followed by the Anthropic response,
    converted with ``anthropic_to_invariant_messages``.

    Args:
        dataset_name: Dataset the trace is pushed to.
        merged_response: The Anthropic response appended as the last message.
        request_body: Parsed request body; its "messages" list is read only.
        invariant_authorization: Value forwarded to ``push_trace``.
    """
    # Build a new list: `+=` on the list obtained from request_body would
    # append the response to the caller's request payload in place.
    messages = [*request_body.get("messages", []), merged_response]

    converted = anthropic_to_invariant_messages(messages)
    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[converted],
        invariant_authorization=invariant_authorization,
    )
|
||||
|
||||
async def handle_non_streaming_response(
    response: httpx.Response,
    dataset_name: str,
    request_body_json: dict[str, Any],
    invariant_authorization: str,
):
    """Push the trace for a non-streaming Anthropic response when it ended a turn.

    Nothing is pushed unless the response's "stop_reason" is one of END_REASONS.
    """
    payload = response.json()

    # Guard clause: responses that did not end the turn are not recorded.
    if payload.get("stop_reason") not in END_REASONS:
        return

    await push_to_explorer(
        dataset_name,
        payload,
        request_body_json,
        invariant_authorization,
    )
|
||||
|
||||
def anthropic_to_invariant_messages(
    messages: list[dict], keep_empty_tool_response: bool = False
) -> list[dict]:
    """Converts a list of messages from the Anthropic API to the Invariant API format.

    Args:
        messages: Messages, each a dict with a "role" and "content" key.
        keep_empty_tool_response: Forwarded to ``handle_user_message``.

    Returns:
        The converted messages in input order. Messages whose role is not
        "system", "user" or "assistant" are dropped.
    """
    # Every handler returns a list because results are merged with extend().
    role_mapping = {
        # Wrapped in a list: extending with a bare dict would add its keys
        # ("role", "content") to the output instead of the message itself.
        "system": lambda msg: [{"role": "system", "content": msg["content"]}],
        "user": lambda msg: handle_user_message(msg, keep_empty_tool_response),
        "assistant": lambda msg: handle_assistant_message(msg),
    }

    output = []
    for message in messages:
        handler = role_mapping.get(message["role"])
        if handler:
            output.extend(handler(message))

    return output
|
||||
|
||||
def handle_user_message(message, keep_empty_tool_response):
    """Convert one Anthropic user message into Invariant-format messages.

    Args:
        message: Dict whose "content" is either a plain value (emitted as one
            user message) or a list of content blocks.
        keep_empty_tool_response: When true, tool_result blocks with empty or
            missing content are still emitted (with ``{"is_error": True}`` or
            ``{}`` as content) instead of being dropped.

    Returns:
        A list of messages: "tool" messages for tool_result blocks and "user"
        messages for text blocks. Other block types are ignored.
    """
    content = message["content"]
    if not isinstance(content, list):
        return [{"role": "user", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "tool_result":
            # "content" and "is_error" are read with .get() so that blocks
            # omitting them do not raise KeyError.
            tool_content = sub_message.get("content")
            if tool_content:
                output.append(
                    {
                        "role": "tool",
                        "content": tool_content,
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
            elif keep_empty_tool_response and any(sub_message.values()):
                output.append(
                    {
                        "role": "tool",
                        "content": (
                            {"is_error": True} if sub_message.get("is_error") else {}
                        ),
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
        elif block_type == "text":
            output.append({"role": "user", "content": sub_message["text"]})
    return output
|
||||
|
||||
def handle_assistant_message(message):
    """Convert one Anthropic assistant message into Invariant-format messages.

    Args:
        message: Dict whose "content" is either a plain string or a list of
            content blocks.

    Returns:
        A list with one assistant message per text block and one assistant
        message carrying a single-entry "tool_calls" list per tool_use block.
        Other block types are ignored.
    """
    content = message["content"]
    # A plain string must not be iterated: that would walk it character by
    # character and fail on sub_message["type"].
    if isinstance(content, str):
        return [{"role": "assistant", "content": content}]

    output = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "text":
            output.append({"role": "assistant", "content": sub_message.get("text")})
        elif block_type == "tool_use":
            tool_call = {
                "tool_id": sub_message.get("id"),
                "type": "function",
                "function": {
                    "name": sub_message.get("name"),
                    "arguments": sub_message.get("input"),
                },
            }
            output.append(
                {"role": "assistant", "content": None, "tool_calls": [tool_call]}
            )
    return output
|
||||
|
||||
@@ -300,7 +300,6 @@ async def push_to_explorer(
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = request_body.get("messages", [])
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
|
||||
_ = await push_trace(
|
||||
dataset_name=dataset_name,
|
||||
messages=[messages],
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import anthropic
|
||||
from typing import Dict
|
||||
import os
|
||||
from tavily import TavilyClient
|
||||
import anthropic
|
||||
from httpx import Client
|
||||
import os
|
||||
# from invariant import testing
|
||||
# Module-level Tavily client; the API key is read from the TAVILY_API_KEY
# environment variable when this module is imported.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
||||
import datetime
|
||||
|
||||
class WeatherAgent:
    """Test agent that answers weather questions through the Anthropic proxy.

    The Anthropic client is pointed at the local proxy URL and sends an
    Invariant-Authorization header with every request. Weather lookups are
    exposed to the model as the ``get_weather`` tool.
    """

    def __init__(self, api_key: str):
        """Create the proxied Anthropic client and the tool definition.

        Args:
            api_key: Anthropic API key handed to the client.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        # A timestamp suffix gives each agent instance its own dataset name.
        dataset_name = "claude_weather_agent_test" + timestamp
        invariant_api_key = os.environ.get("INVARIANT_API_KEY")
        self.client = anthropic.Anthropic(
            # Previously the argument was accepted but never used.
            api_key=api_key,
            http_client=Client(
                headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
            ),
            base_url=f"http://localhost/api/v1/proxy/{dataset_name}/anthropic",
        )
        self.get_weather_function = {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\"",
                    },
                },
                "required": ["location"],
            },
        }

    def get_response(self, user_query: str) -> str:
        """Get the agent's text answer for a weather query.

        Calls the model in a loop: while the reply contains tool_use blocks,
        each one is answered with a tool_result and the model is called again.

        Args:
            user_query: The user's question.

        Returns:
            The concatenated text blocks of the first reply without tool calls.
        """
        messages = [{"role": "user", "content": user_query}]
        while True:
            response = self.client.messages.create(
                tools=[self.get_weather_function],
                model="claude-3-5-sonnet-20241022",
                max_tokens=1024,
                messages=messages,
            )

            # Look at every block: a tool call is not guaranteed to sit at
            # index 1 behind a leading text block.
            tool_calls = [
                block for block in response.content if block.type == "tool_use"
            ]
            if not tool_calls:
                return "".join(
                    block.text for block in response.content if block.type == "text"
                )

            messages.append({"role": response.role, "content": response.content})
            # One tool_result per tool_use block, all in a single user turn.
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": call.id,
                            "content": self.get_weather(call.input["location"]),
                        }
                        for call in tool_calls
                    ],
                }
            )

    def get_weather(self, location: str):
        """Look up the current weather for a location via the Tavily search client.

        Args:
            location: Name of the place to look up.

        Returns:
            The title and content of the first search result, or a short
            "not found" message when the search returns no results.
        """
        query = f"What is the weather in {location}?"
        response = tavily_client.search(query)
        results = response.get("results") or []
        if not results:
            return f"No weather information found for {location}."
        top_result = results[0]
        return top_result["title"] + ":\n" + top_result["content"]
|
||||
|
||||
|
||||
# Initialize agent with your Anthropic API key
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
# Constructed once at import time and shared by the test in this module.
weather_agent = WeatherAgent(anthropic_api_key)
|
||||
|
||||
def test_proxy_response():
    """Each weather query must yield a non-None answer that names its city."""
    # (query, city expected to appear in the answer)
    expectations = (
        ("What's the weather like in Zurich city?", "Zurich"),
        ("Tell me the forecast for New York", "New York"),
        ("How's the weather in London next week?", "London"),
    )
    for query, expected_city in expectations:
        response = weather_agent.get_response(query)
        assert response is not None
        assert expected_city in response
|
||||
Reference in New Issue
Block a user