mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-02 01:05:29 +02:00
Anthropic related add
Anthropic related add
This commit is contained in:
+44
-32
@@ -53,7 +53,6 @@ async def anthropic_proxy(
|
||||
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
|
||||
if request.headers.get(
|
||||
"invariant-authorization"
|
||||
) is None and "|invariant-auth:" not in request.headers.get(HEADER_AUTHORIZATION):
|
||||
@@ -88,7 +87,7 @@ async def anthropic_proxy(
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to fetch response: {response.text}, got error{e}",
|
||||
detail=f"Failed to fetch response from Anthropic: {response.text}, got error{e}",
|
||||
)
|
||||
await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
@@ -124,7 +123,18 @@ async def handle_non_streaming_response(
|
||||
invariant_authorization: str,
|
||||
):
|
||||
"""Handles non-streaming Anthropic responses"""
|
||||
json_response = response.json()
|
||||
try:
|
||||
json_response = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
|
||||
) from e
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
if json_response.get("stop_reason") in END_REASONS:
|
||||
await push_to_explorer(
|
||||
@@ -142,50 +152,52 @@ async def handle_streaming_response(
|
||||
) -> StreamingResponse:
|
||||
|
||||
formatted_invariant_response = []
|
||||
|
||||
response = await client.send(anthropic_request, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
error_json = json.loads(error_content)
|
||||
error_detail = error_json.get("error", "Unknown error from Anthropic")
|
||||
except json.JSONDecodeError:
|
||||
error_detail = {"error": "Failed to decode error response from Anthropic"}
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
|
||||
async def event_generator() -> Any:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
anthropic_request.url,
|
||||
headers=anthropic_request.headers,
|
||||
content=anthropic_request.content,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
yield json.dumps(
|
||||
{"error": f"Failed to fetch response: {response.status_code}"}
|
||||
).encode()
|
||||
return
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
async for chunk in response.aiter_bytes():
|
||||
chunk_decode = chunk.decode().strip()
|
||||
if not chunk_decode:
|
||||
continue
|
||||
|
||||
process_chunk_text(
|
||||
chunk,
|
||||
formatted_invariant_response
|
||||
)
|
||||
yield chunk
|
||||
|
||||
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
process_chunk_text(
|
||||
chunk_decode,
|
||||
formatted_invariant_response
|
||||
)
|
||||
|
||||
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
return StreamingResponse(generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def process_chunk_text(chunk, formatted_invariant_response):
|
||||
def process_chunk_text(chunk_decode, formatted_invariant_response):
|
||||
"""
|
||||
Process the chunk of text and update the formatted_invariant_response
|
||||
Example of chunk list can be find in:
|
||||
../../resources/streaming_chunk_text/anthropic.txt
|
||||
"""
|
||||
text_decode = chunk.decode().strip()
|
||||
for text_block in text_decode.split("\n\n"):
|
||||
# might be empty block
|
||||
|
||||
for text_block in chunk_decode.split("\n\n"):
|
||||
# might be empty block
|
||||
if len(text_block.split("\ndata:"))>1:
|
||||
text_data = text_block.split("\ndata:")[1]
|
||||
text_json = json.loads(text_data)
|
||||
|
||||
@@ -52,13 +52,13 @@ async def openai_proxy(
|
||||
# The invariant-authorization header contains the Invariant API Key
|
||||
# "invariant-authorization": "Bearer <Invariant API Key>"
|
||||
# The authorization header contains the OpenAI API Key
|
||||
# "authorization": "Bearer <OpenAI API Key>"
|
||||
# "authorization": "<OpenAI API Key>"
|
||||
#
|
||||
# For some clients, it is not possible to pass a custom header
|
||||
# In such cases, the Invariant API Key is passed as part of the
|
||||
# authorization header with the OpenAI API key.
|
||||
# The header in that case becomes:
|
||||
# "authorization": "Bearer <OpenAI API Key>|invariant-auth: <Invariant API Key>"
|
||||
# "authorization": "<OpenAI API Key>|invariant-auth: <Invariant API Key>"
|
||||
if request.headers.get(
|
||||
"invariant-authorization"
|
||||
) is None and "|invariant-auth:" not in request.headers.get("authorization"):
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
import anthropic
|
||||
from httpx import Client
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from util import * # needed for pytest fixtures
|
||||
|
||||
pytest_plugins = ("pytest_asyncio")
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_header(
|
||||
context, proxy_url, explorer_api_url
|
||||
):
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
dataset_name = "claude_header_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": anthropic_api_key + "|invariant-auth: <not needed for test>"}):
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(),
|
||||
base_url = f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
response = client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
response_text = response.content[0].text
|
||||
assert "zurich" in response_text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
}
|
||||
]
|
||||
@@ -156,7 +156,7 @@ class WeatherAgent:
|
||||
return response
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_chat_completion_without_streaming(
|
||||
async def test_response_with_toolcall(
|
||||
context, explorer_api_url, proxy_url
|
||||
):
|
||||
"""Test the chat completion without streaming for the weather agent."""
|
||||
@@ -164,7 +164,7 @@ async def test_chat_completion_without_streaming(
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
|
||||
queries = [
|
||||
"What's the weather like in Zurich city?",
|
||||
"What's the weather like in Zurich, Switzerland?",
|
||||
"Tell me the weather for New York",
|
||||
]
|
||||
cities = ["zurich", "new york"]
|
||||
@@ -213,14 +213,14 @@ async def test_chat_completion_without_streaming(
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_chat_completion_with_streaming(
|
||||
async def test_streaming_response_with_toolcall(
|
||||
context, explorer_api_url, proxy_url
|
||||
):
|
||||
"""Test the chat completion with streaming for the weather agent."""
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
|
||||
queries = [
|
||||
"What's the weather like in Zurich city?",
|
||||
"What's the weather like in Zurich, Switzerland?",
|
||||
"Tell me the weather for New York",
|
||||
]
|
||||
cities = ["zurich", "new york"]
|
||||
|
||||
@@ -10,7 +10,7 @@ from util import * # needed for pytest fixtures
|
||||
|
||||
pytest_plugins = ("pytest_asyncio")
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_chat_completion_without_streaming(
|
||||
async def test_response_without_toolcall(
|
||||
context, explorer_api_url,proxy_url
|
||||
):
|
||||
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
@@ -25,7 +25,7 @@ async def test_chat_completion_without_streaming(
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
"Can you introduce Zurich city within 200 words?",
|
||||
"Can you introduce Zurich, Switzerland within 200 words?",
|
||||
"Tell me the history of New York within 100 words?",
|
||||
"How's the weather in London next week?"
|
||||
]
|
||||
@@ -71,7 +71,8 @@ async def test_chat_completion_without_streaming(
|
||||
async def test_streaming_response_without_toolcall(
|
||||
context,
|
||||
explorer_api_url,
|
||||
proxy_url):
|
||||
proxy_url
|
||||
):
|
||||
|
||||
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
|
||||
@@ -85,7 +86,7 @@ async def test_streaming_response_without_toolcall(
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
"Can you introduce Zurich city within 200 words?",
|
||||
"Can you introduce Zurich, Switzerland within 200 words?",
|
||||
"Tell me the history of New York within 100 words?",
|
||||
"How's the weather in London next week?"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user