Files
invariant-gateway/proxy/routes/anthropic.py
T

348 lines
13 KiB
Python

"""Proxy service to forward requests to the Anthropic APIs"""
import json
from typing import Any, Optional
import httpx
from common.config_manager import ProxyConfig, ProxyConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
INVARIANT_AUTHORIZATION_HEADER,
)
from utils.explorer import push_trace
# FastAPI router that the Anthropic proxy endpoints in this module are registered on.
proxy = APIRouter()
# Error messages returned to clients.
MISSING_INVARIANT_AUTH_API_KEY = "Missing invariant authorization header"
MISSING_ANTHROPIC_AUTH_HEADER = "Missing Anthropic authorization header"
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "

# Anthropic `stop_reason` values that finish an assistant turn.
END_REASONS = ["end_turn", "max_tokens", "stop_sequence"]

# Event types emitted by the Anthropic streaming Messages API.
MESSAGE_START = "message_start"
MESSAGE_DELTA = "message_delta"
# Historical misspelling, kept as an alias so existing references keep working.
MESSGAE_DELTA = MESSAGE_DELTA
MESSAGE_STOP = "message_stop"
CONTENT_BLOCK_START = "content_block_start"
CONTENT_BLOCK_DELTA = "content_block_delta"
CONTENT_BLOCK_STOP = "content_block_stop"

# Header that carries the Anthropic API key (optionally with an embedded Invariant key).
ANTHROPIC_AUTHORIZATION_HEADER = "x-api-key"
def validate_headers(x_api_key: str = Header(None)):
    """FastAPI dependency that rejects requests lacking an ``x-api-key`` header.

    Raises:
        HTTPException: status 400 when the header is absent.
    """
    if x_api_key is not None:
        return
    raise HTTPException(status_code=400, detail=MISSING_ANTHROPIC_AUTH_HEADER)
def _split_anthropic_api_key(header_value: str) -> tuple[str, Optional[str]]:
    """Split ``"<Anthropic key>|invariant-auth:<Invariant key>"`` into its two keys.

    Whitespace around either key (e.g. after the colon) is ignored.

    Returns:
        ``(anthropic_key, invariant_key)``. When the separator is absent the header
        value is returned unchanged as the Anthropic key; ``invariant_key`` is None
        when no non-empty Invariant key is embedded.
    """
    anthropic_key, separator, invariant_key = header_value.partition("|invariant-auth:")
    if not separator:
        return header_value, None
    return anthropic_key.strip(), invariant_key.strip() or None


@proxy.post(
    "/{dataset_name}/anthropic/v1/messages",
    dependencies=[Depends(validate_headers)],
)
@proxy.post(
    "/anthropic/v1/messages",
    dependencies=[Depends(validate_headers)],
)
async def anthropic_v1_messages_proxy(
    request: Request,
    dataset_name: Optional[str] = None,  # This is None if the client doesn't want to push to Explorer
    config: ProxyConfig = Depends(ProxyConfigManager.get_config),  # pylint: disable=unused-argument
):
    """Proxy calls to the Anthropic Messages API.

    The request body is forwarded unchanged to https://api.anthropic.com/v1/messages.
    Bodies with ``"stream"`` set are relayed through ``handle_streaming_response``,
    all others through ``handle_non_streaming_response``. When ``dataset_name`` is
    given, the resulting trace is pushed to Explorer with the Invariant API key.

    Raises:
        HTTPException: 400 when a dataset is requested without an Invariant API key,
            or when the request body is not a JSON object.
    """
    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
    }
    headers["accept-encoding"] = "identity"
    # The Invariant key is only meant for Explorer; never forward it to Anthropic.
    headers.pop(INVARIANT_AUTHORIZATION_HEADER.lower(), None)

    # In case the user wants to push to Explorer, the request must contain the Invariant API Key
    # The invariant-authorization header contains the Invariant API Key
    # "invariant-authorization": "Bearer <Invariant API Key>"
    # The x-api-key header contains the Anthropic API Key
    # "x-api-key": "<Anthropic API Key>"
    #
    # For some clients, it is not possible to pass a custom header
    # In such cases, the Invariant API Key is passed as part of the
    # x-api-key header with the Anthropic API key.
    # The header in that case becomes:
    # "x-api-key": "<Anthropic API Key>|invariant-auth: <Invariant API Key>"
    # (the space after the colon is optional).
    raw_api_key = request.headers.get(ANTHROPIC_AUTHORIZATION_HEADER) or ""
    anthropic_api_key, embedded_invariant_key = _split_anthropic_api_key(raw_api_key)
    if anthropic_api_key != raw_api_key:
        # Strip the embedded Invariant key even when no dataset is requested:
        # Anthropic would reject the combined value and must not see that key.
        headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthropic_api_key

    invariant_authorization = None
    if dataset_name:
        explicit_invariant_auth = request.headers.get(INVARIANT_AUTHORIZATION_HEADER)
        if explicit_invariant_auth:
            invariant_authorization = explicit_invariant_auth
        elif embedded_invariant_key:
            invariant_authorization = f"Bearer {embedded_invariant_key}"
        else:
            raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY)

    request_body = await request.body()
    try:
        request_body_json = json.loads(request_body)
    except ValueError as e:  # json.JSONDecodeError or invalid UTF-8
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from e
    if not isinstance(request_body_json, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
    anthropic_request = client.build_request(
        "POST",
        "https://api.anthropic.com/v1/messages",
        headers=headers,
        content=request_body,  # raw bytes: `content=`, not the deprecated `data=`
    )
    if request_body_json.get("stream"):
        # The client must stay open while the response streams to the caller,
        # so it is handed over to the streaming handler.
        return await handle_streaming_response(
            client, anthropic_request, dataset_name, invariant_authorization
        )
    try:
        # Without stream=True the whole body is read inside send(), so the
        # response stays usable after the client is closed.
        response = await client.send(anthropic_request)
    finally:
        await client.aclose()
    return await handle_non_streaming_response(
        response, dataset_name, request_body_json, invariant_authorization
    )
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Pushes the full trace to the Invariant Explorer.

    The trace is the request's ``messages`` followed by the Anthropic reply,
    converted with ``convert_anthropic_to_invariant_message_format``.

    Args:
        dataset_name: Explorer dataset to push into.
        merged_response: the Anthropic response message (role "assistant").
        request_body: the JSON body sent to Anthropic; it is not modified.
        invariant_authorization: value passed through to ``push_trace``.
    """
    # Build a new list: `+=` on request_body["messages"] would mutate the
    # caller's request body in place.
    messages = [*(request_body.get("messages") or []), merged_response]
    transformed_messages = convert_anthropic_to_invariant_message_format(messages)
    await push_trace(
        dataset_name=dataset_name,
        messages=[transformed_messages],
        invariant_authorization=invariant_authorization,
    )
async def handle_non_streaming_response(
    response: httpx.Response,
    dataset_name: Optional[str],
    request_body_json: dict[str, Any],
    invariant_authorization: Optional[str],
) -> Response:
    """Handles non-streaming Anthropic responses.

    Relays the upstream body and headers to the caller and, when ``dataset_name``
    is set, pushes the trace to Explorer first.

    Raises:
        HTTPException: when the upstream body is not JSON (upstream error status,
            or 502 for an otherwise successful status), or when the upstream status
            is not 200 (upstream status, with the ``error`` field as detail).
    """
    try:
        json_response = response.json()
    except ValueError as e:  # json.JSONDecodeError, or UnicodeDecodeError on bad bytes
        raise HTTPException(
            # An unparseable body on a non-error status is a gateway failure,
            # not a success: don't answer the caller with e.g. status 200.
            status_code=response.status_code if response.is_error else 502,
            detail=f"Invalid JSON response received from Anthropic: {response.text}, got error: {e}",
        ) from e

    if response.status_code != 200:
        error_detail = (
            json_response.get("error", "Unknown error from Anthropic")
            if isinstance(json_response, dict)
            else json_response
        )
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    # Every successful response is pushed when a dataset is requested;
    # the response's stop_reason is not inspected here.
    if dataset_name:
        await push_to_explorer(
            dataset_name,
            json_response,
            request_body_json,
            invariant_authorization,
        )

    # Body-framing headers must describe the body actually sent. httpx has
    # already decoded `response.content`, so let Starlette recompute
    # Content-Length instead of copying the upstream values.
    excluded_headers = {
        "content-length",
        "content-encoding",
        "transfer-encoding",
        "connection",
    }
    forwarded_headers = {
        k: v for k, v in response.headers.items() if k.lower() not in excluded_headers
    }
    return Response(
        content=response.content,
        status_code=response.status_code,
        media_type="application/json",
        headers=forwarded_headers,
    )
async def handle_streaming_response(
    client: httpx.AsyncClient,
    anthropic_request: httpx.Request,
    dataset_name: Optional[str],
    invariant_authorization: Optional[str],
) -> StreamingResponse:
    """Handles streaming Anthropic responses.

    Sends ``anthropic_request`` with ``stream=True`` and relays the upstream
    server-sent events to the caller byte-for-byte. Meanwhile, complete events
    are parsed with ``process_chunk_text`` into ``merged_response``. When
    ``dataset_name`` is set, the last merged message is pushed to Explorer once
    the upstream stream ends.

    This function takes ownership of ``client``. The client and the upstream
    response are closed when the stream finishes or is aborted, and also when
    the upstream call fails.

    Raises:
        HTTPException: with the upstream status code when Anthropic does not
            answer with HTTP 200.
    """
    merged_response = []
    try:
        response = await client.send(anthropic_request, stream=True)
    except BaseException:
        await client.aclose()
        raise

    if response.status_code != 200:
        try:
            error_content = await response.aread()
        finally:
            await response.aclose()
            await client.aclose()
        try:
            error_json = json.loads(error_content)
        except ValueError:  # json.JSONDecodeError or invalid UTF-8
            error_json = None
        if isinstance(error_json, dict):
            error_detail = error_json.get("error", "Unknown error from Anthropic")
        else:
            error_detail = {"error": "Failed to decode error response from Anthropic"}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    async def event_generator() -> Any:
        pending = b""
        try:
            async for chunk in response.aiter_bytes():
                # Relay every chunk untouched, including whitespace-only ones:
                # they may carry the blank line that terminates an SSE event.
                yield chunk
                pending += chunk
                # Chunk boundaries are arbitrary. Only parse up to the last
                # complete event ("\n\n") and keep the tail for the next chunk.
                # Splitting on bytes also avoids decoding half of a multi-byte
                # UTF-8 character.
                boundary = pending.rfind(b"\n\n")
                if boundary != -1:
                    complete, pending = pending[:boundary], pending[boundary + 2 :]
                    process_chunk_text(complete.decode(), merged_response)
            if pending.strip():
                process_chunk_text(pending.decode(), merged_response)
            if dataset_name and merged_response:
                await push_to_explorer(
                    dataset_name,
                    merged_response[-1],
                    json.loads(anthropic_request.content),
                    invariant_authorization,
                )
        finally:
            await response.aclose()
            await client.aclose()

    generator = event_generator()
    return StreamingResponse(generator, media_type="text/event-stream")
def process_chunk_text(chunk_decode, merged_response):
"""
Process the chunk of text and update the merged_response
Example of chunk list can be find in:
../../resources/streaming_chunk_text/anthropic.txt
"""
for text_block in chunk_decode.split("\n\n"):
# might be empty block
if len(text_block.split("\ndata:")) > 1:
text_data = text_block.split("\ndata:")[1]
text_json = json.loads(text_data)
update_merged_response(text_json, merged_response)
def update_merged_response(text_json, merged_response):
    """Fold one decoded Anthropic streaming event into ``merged_response``.

    Each ``message_start`` appends a message dict shaped like a non-streaming
    Messages API response. It has ``id``, ``role``, ``model``, ``stop_reason``,
    ``stop_sequence``, and ``content`` as a list of content blocks. Later events
    update the most recent message:

    - ``content_block_start`` appends the announced block.
    - ``content_block_delta`` extends the block at the event's ``index``:
      ``text_delta`` goes to ``text``, ``thinking_delta`` to ``thinking``, and
      ``input_json_delta`` to the raw ``input`` string. Other delta types are
      ignored.
    - ``content_block_stop`` parses a tool block's accumulated ``input`` JSON into
      a dict. An empty input becomes ``{}``; unparseable input stays a string.
    - ``message_delta`` records ``stop_reason`` (and ``stop_sequence`` if present).

    Events that arrive before any ``message_start`` are ignored.
    """
    event_type = text_json.get("type")

    if event_type == MESSAGE_START:
        message = text_json.get("message") or {}
        merged_response.append(
            {
                "id": message.get("id"),
                "role": message.get("role"),
                "content": [],
                "model": message.get("model"),
                "stop_reason": message.get("stop_reason"),
                "stop_sequence": message.get("stop_sequence"),
            }
        )
        return

    if not merged_response:
        return
    current = merged_response[-1]
    blocks = current["content"]

    if event_type == CONTENT_BLOCK_START:
        block = dict(text_json.get("content_block") or {})
        if block.get("type") == "tool_use":
            # Tool arguments arrive as JSON fragments: collect them as a string
            # and parse once the block is complete.
            block["input"] = ""
        blocks.append(block)
    elif event_type == CONTENT_BLOCK_DELTA:
        block = _content_block_at(blocks, text_json.get("index"))
        delta = text_json.get("delta") or {}
        fields = _DELTA_FIELDS.get(delta.get("type"))
        if block is None or fields is None:
            return
        target_field, source_field = fields
        existing = block.get(target_field)
        block[target_field] = (existing if isinstance(existing, str) else "") + (
            delta.get(source_field) or ""
        )
    elif event_type == CONTENT_BLOCK_STOP:
        block = _content_block_at(blocks, text_json.get("index"))
        if block is not None and isinstance(block.get("input"), str):
            raw_input = block["input"]
            try:
                block["input"] = json.loads(raw_input) if raw_input.strip() else {}
            except json.JSONDecodeError:
                pass  # keep the raw string so the trace still shows what was sent
    elif event_type == MESSGAE_DELTA:
        delta = text_json.get("delta") or {}
        current["stop_reason"] = delta.get("stop_reason")
        if "stop_sequence" in delta:
            current["stop_sequence"] = delta["stop_sequence"]


# Streaming delta type -> (content-block field to extend, delta field holding the fragment).
_DELTA_FIELDS = {
    "text_delta": ("text", "text"),
    "thinking_delta": ("thinking", "thinking"),
    "input_json_delta": ("input", "partial_json"),
}


def _content_block_at(blocks, index):
    """Return the content block addressed by a streaming event's ``index``.

    Falls back to the most recent block when ``index`` is missing or out of range,
    and returns None when no block has been started yet.
    """
    if isinstance(index, int) and 0 <= index < len(blocks):
        return blocks[index]
    return blocks[-1] if blocks else None
def convert_anthropic_to_invariant_message_format(
    messages: list[dict], keep_empty_tool_response: bool = False
) -> list[dict]:
    """Converts a list of messages from the Anthropic API to the Invariant API format.

    ``system`` messages are copied as one message. ``user`` messages go through
    ``handle_user_message`` and ``assistant`` messages through
    ``handle_assistant_message``, and each of those may expand into several
    messages. Messages with any other role, or with no role, are dropped.

    Args:
        messages: Anthropic-style messages, each a dict with ``role`` and ``content``.
        keep_empty_tool_response: forwarded to ``handle_user_message``.

    Returns:
        The converted messages, in input order.
    """
    output = []
    for message in messages:
        role = message.get("role")
        if role == "system":
            # append, not extend: extending with a dict would add its *keys*.
            output.append({"role": "system", "content": message["content"]})
        elif role == "user":
            output.extend(handle_user_message(message, keep_empty_tool_response))
        elif role == "assistant":
            output.extend(handle_assistant_message(message))
    return output
def handle_user_message(message, keep_empty_tool_response):
    """Handle the user message from the Anthropic API.

    String (non-list) content becomes a single ``user`` message. List content is
    split as follows:
    - each ``tool_result`` block becomes its own ``tool`` message, emitted first;
    - ``text`` and ``image`` blocks are gathered into one trailing ``user``
      message, with images converted to ``image_url`` entries;
    - all other block types are dropped.

    Args:
        message: an Anthropic message dict with a ``content`` field.
        keep_empty_tool_response: if True, a ``tool_result`` with empty or
            missing content is still emitted. Its content is
            ``{"is_error": True}`` when the block is flagged as an error,
            otherwise ``{}``.

    Returns:
        The list of converted messages.
    """
    content = message["content"]
    if not isinstance(content, list):
        return [{"role": "user", "content": content}]

    output = []
    user_content = []
    for sub_message in content:
        block_type = sub_message["type"]
        if block_type == "tool_result":
            # Both `content` and `is_error` are optional on tool_result blocks.
            tool_output = sub_message.get("content")
            if tool_output:
                output.append(
                    {
                        "role": "tool",
                        "content": tool_output,
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
            elif keep_empty_tool_response:
                output.append(
                    {
                        "role": "tool",
                        "content": {"is_error": True}
                        if sub_message.get("is_error")
                        else {},
                        "tool_id": sub_message["tool_use_id"],
                    }
                )
        elif block_type == "text":
            user_content.append({"type": "text", "text": sub_message["text"]})
        elif block_type == "image":
            source = sub_message["source"]
            if source.get("type") == "url":
                url = source["url"]
            else:
                # base64 source: inline it as a data URL
                url = f"data:{source['media_type']};base64,{source['data']}"
            user_content.append({"type": "image_url", "image_url": {"url": url}})
    if user_content:
        output.append({"role": "user", "content": user_content})
    return output
def handle_assistant_message(message):
    """Translate one Anthropic assistant message into Invariant-format messages.

    Non-list ``content`` is passed through as a single assistant message. For a
    list, each ``text`` block becomes an assistant message carrying that text.
    Each ``tool_use`` block becomes an assistant message with ``content=None``
    and a single ``function`` tool call. Any other block type is dropped.
    """
    blocks = message["content"]
    if not isinstance(blocks, list):
        return [{"role": "assistant", "content": blocks}]

    converted = []
    for block in blocks:
        block_type = block["type"]
        if block_type == "text":
            converted.append({"role": "assistant", "content": block.get("text")})
            continue
        if block_type != "tool_use":
            continue
        tool_call = {
            "tool_id": block.get("id"),
            "type": "function",
            "function": {
                "name": block.get("name"),
                "arguments": block.get("input"),
            },
        }
        converted.append(
            {"role": "assistant", "content": None, "tool_calls": [tool_call]}
        )
    return converted