Anthropic related add

Anthropic related add
This commit is contained in:
zishan-wei
2025-02-20 17:24:05 +01:00
committed by GitHub
5 changed files with 117 additions and 42 deletions
+44 -32
View File
@@ -53,7 +53,6 @@ async def anthropic_proxy(
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
headers["accept-encoding"] = "identity"
if request.headers.get(
"invariant-authorization"
) is None and "|invariant-auth:" not in request.headers.get(HEADER_AUTHORIZATION):
@@ -88,7 +87,7 @@ async def anthropic_proxy(
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to fetch response: {response.text}, got error{e}",
detail=f"Failed to fetch response from Anthropic: {response.text}, got error{e}",
)
await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
@@ -124,7 +123,18 @@ async def handle_non_streaming_response(
invariant_authorization: str,
):
"""Handles non-streaming Anthropic responses"""
json_response = response.json()
try:
json_response = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail=f"Invalid JSON response received from Anthropic: {response.text}, got error{e}",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
if json_response.get("stop_reason") in END_REASONS:
await push_to_explorer(
@@ -142,50 +152,52 @@ async def handle_streaming_response(
) -> StreamingResponse:
formatted_invariant_response = []
response = await client.send(anthropic_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content)
error_detail = error_json.get("error", "Unknown error from Anthropic")
except json.JSONDecodeError:
error_detail = {"error": "Failed to decode error response from Anthropic"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
async def event_generator() -> Any:
async with client.stream(
"POST",
anthropic_request.url,
headers=anthropic_request.headers,
content=anthropic_request.content,
) as response:
if response.status_code != 200:
yield json.dumps(
{"error": f"Failed to fetch response: {response.status_code}"}
).encode()
return
async for chunk in response.aiter_bytes():
yield chunk
async for chunk in response.aiter_bytes():
chunk_decode = chunk.decode().strip()
if not chunk_decode:
continue
process_chunk_text(
chunk,
formatted_invariant_response
)
yield chunk
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
process_chunk_text(
chunk_decode,
formatted_invariant_response
)
if formatted_invariant_response and formatted_invariant_response[-1].get("stop_reason") in END_REASONS:
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
generator = event_generator()
return StreamingResponse(generator, media_type="text/event-stream")
def process_chunk_text(chunk, formatted_invariant_response):
def process_chunk_text(chunk_decode, formatted_invariant_response):
"""
Process the chunk of text and update the formatted_invariant_response
Example of chunk list can be find in:
../../resources/streaming_chunk_text/anthropic.txt
"""
text_decode = chunk.decode().strip()
for text_block in text_decode.split("\n\n"):
# might be empty block
for text_block in chunk_decode.split("\n\n"):
# might be empty block
if len(text_block.split("\ndata:"))>1:
text_data = text_block.split("\ndata:")[1]
text_json = json.loads(text_data)
+2 -2
View File
@@ -52,13 +52,13 @@ async def openai_proxy(
# The invariant-authorization header contains the Invariant API Key
# "invariant-authorization": "Bearer <Invariant API Key>"
# The authorization header contains the OpenAI API Key
# "authorization": "Bearer <OpenAI API Key>"
# "authorization": "<OpenAI API Key>"
#
# For some clients, it is not possible to pass a custom header
# In such cases, the Invariant API Key is passed as part of the
# authorization header with the OpenAI API key.
# The header in that case becomes:
# "authorization": "Bearer <OpenAI API Key>|invariant-auth: <Invariant API Key>"
# "authorization": "<OpenAI API Key>|invariant-auth: <Invariant API Key>"
if request.headers.get(
"invariant-authorization"
) is None and "|invariant-auth:" not in request.headers.get("authorization"):
@@ -0,0 +1,62 @@
from unittest.mock import patch
import os
import anthropic
from httpx import Client
import datetime
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
pytest_plugins = ("pytest_asyncio")
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_header(
context, proxy_url, explorer_api_url
):
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
dataset_name = "claude_header_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": anthropic_api_key + "|invariant-auth: <not needed for test>"}):
client = anthropic.Anthropic(
http_client=Client(),
base_url = f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
)
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
}
]
)
assert response is not None
response_text = response.content[0].text
assert "zurich" in response_text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
get_trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await get_trace_response.json()
assert trace["messages"] == [
{
"role": "user",
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
},
{
"role": "assistant",
"content": response_text
}
]
@@ -156,7 +156,7 @@ class WeatherAgent:
return response
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_chat_completion_without_streaming(
async def test_response_with_toolcall(
context, explorer_api_url, proxy_url
):
"""Test the chat completion without streaming for the weather agent."""
@@ -164,7 +164,7 @@ async def test_chat_completion_without_streaming(
weather_agent = WeatherAgent(proxy_url)
queries = [
"What's the weather like in Zurich city?",
"What's the weather like in Zurich, Switzerland?",
"Tell me the weather for New York",
]
cities = ["zurich", "new york"]
@@ -213,14 +213,14 @@ async def test_chat_completion_without_streaming(
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_chat_completion_with_streaming(
async def test_streaming_response_with_toolcall(
context, explorer_api_url, proxy_url
):
"""Test the chat completion with streaming for the weather agent."""
weather_agent = WeatherAgent(proxy_url)
queries = [
"What's the weather like in Zurich city?",
"What's the weather like in Zurich, Switzerland?",
"Tell me the weather for New York",
]
cities = ["zurich", "new york"]
@@ -10,7 +10,7 @@ from util import * # needed for pytest fixtures
pytest_plugins = ("pytest_asyncio")
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_chat_completion_without_streaming(
async def test_response_without_toolcall(
context, explorer_api_url,proxy_url
):
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
@@ -25,7 +25,7 @@ async def test_chat_completion_without_streaming(
cities = ["zurich", "new york", "london"]
queries = [
"Can you introduce Zurich city within 200 words?",
"Can you introduce Zurich, Switzerland within 200 words?",
"Tell me the history of New York within 100 words?",
"How's the weather in London next week?"
]
@@ -71,7 +71,8 @@ async def test_chat_completion_without_streaming(
async def test_streaming_response_without_toolcall(
context,
explorer_api_url,
proxy_url):
proxy_url
):
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
@@ -85,7 +86,7 @@ async def test_streaming_response_without_toolcall(
cities = ["zurich", "new york", "london"]
queries = [
"Can you introduce Zurich city within 200 words?",
"Can you introduce Zurich, Switzerland within 200 words?",
"Tell me the history of New York within 100 words?",
"How's the weather in London next week?"
]