Add API endpoints so that the Proxy can be used without pushing to Explorer.

This commit is contained in:
Hemang
2025-02-25 21:21:24 +01:00
committed by Hemang Sarkar
parent 6d6f4d62c7
commit 6afbcd3ea0
7 changed files with 513 additions and 417 deletions
+34 -26
View File
@@ -1,10 +1,10 @@
"""Proxy service to forward requests to the Anthropic APIs"""
import json
from typing import Any
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from utils.explorer import push_trace
@@ -36,9 +36,13 @@ def validate_headers(x_api_key: str = Header(None)):
"/{dataset_name}/anthropic/v1/messages",
dependencies=[Depends(validate_headers)],
)
@proxy.post(
"/anthropic/v1/messages",
dependencies=[Depends(validate_headers)],
)
async def anthropic_v1_messages_proxy(
dataset_name: str,
request: Request,
dataset_name: str = None,
):
"""Proxy calls to the Anthropic APIs"""
headers = {
@@ -77,17 +81,10 @@ async def anthropic_v1_messages_proxy(
client, anthropic_request, dataset_name, invariant_authorization
)
else:
try:
response = await client.send(anthropic_request)
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to fetch response from Anthropic: {response.text}, got error{e}",
)
await handle_non_streaming_response(
response = await client.send(anthropic_request)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
)
return response.json()
async def push_to_explorer(
@@ -116,7 +113,7 @@ async def handle_non_streaming_response(
dataset_name: str,
request_body_json: dict[str, Any],
invariant_authorization: str,
):
) -> Response:
"""Handles non-streaming Anthropic responses"""
try:
json_response = response.json()
@@ -131,20 +128,28 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from Anthropic"),
)
# Only push the trace to explorer if the last message is an end turn message
await push_to_explorer(
dataset_name,
json_response,
request_body_json,
invariant_authorization,
if dataset_name:
await push_to_explorer(
dataset_name,
json_response,
request_body_json,
invariant_authorization,
)
return Response(
content=json.dumps(json_response),
status_code=response.status_code,
media_type="application/json",
headers=dict(response.headers),
)
async def handle_streaming_response(
client: httpx.AsyncClient,
anthropic_request: httpx.Request,
dataset_name: str,
dataset_name: Optional[str],
invariant_authorization: str,
) -> StreamingResponse:
"""Handles streaming Anthropic responses"""
formatted_invariant_response = []
response = await client.send(anthropic_request, stream=True)
@@ -165,13 +170,13 @@ async def handle_streaming_response(
yield chunk
process_chunk_text(chunk_decode, formatted_invariant_response)
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
if dataset_name:
await push_to_explorer(
dataset_name,
formatted_invariant_response[-1],
json.loads(anthropic_request.content),
invariant_authorization,
)
generator = event_generator()
@@ -193,6 +198,7 @@ def process_chunk_text(chunk_decode, formatted_invariant_response):
def update_formatted_invariant_response(text_json, formatted_invariant_response):
"""Update the formatted_invariant_response based on the text_json"""
if text_json.get("type") == MESSAGE_START:
message = text_json.get("message")
formatted_invariant_response.append(
@@ -252,6 +258,7 @@ def anthropic_to_invariant_messages(
def handle_user_message(message, keep_empty_tool_response):
"""Handle the user message from the Anthropic API"""
output = []
content = message["content"]
if isinstance(content, list):
@@ -298,6 +305,7 @@ def handle_user_message(message, keep_empty_tool_response):
def handle_assistant_message(message):
"""Handle the assistant message from the Anthropic API"""
output = []
if isinstance(message["content"], list):
for sub_message in message["content"]:
+20 -14
View File
@@ -1,7 +1,7 @@
"""Proxy service to forward requests to the OpenAI APIs"""
import json
from typing import Any
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
@@ -26,9 +26,13 @@ def validate_headers(authorization: str = Header(None)):
"/{dataset_name}/openai/chat/completions",
dependencies=[Depends(validate_headers)],
)
@proxy.post(
"/openai/chat/completions",
dependencies=[Depends(validate_headers)],
)
async def openai_chat_completions_proxy(
request: Request,
dataset_name: str,
dataset_name: str = None,
) -> Response:
"""Proxy calls to the OpenAI APIs"""
@@ -92,7 +96,7 @@ async def openai_chat_completions_proxy(
async def stream_response(
client: httpx.AsyncClient,
open_ai_request: httpx.Request,
dataset_name: str,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: str,
) -> Response:
@@ -150,12 +154,13 @@ async def stream_response(
)
# Send full merged response to the explorer
await push_to_explorer(
dataset_name,
merged_response,
request_body_json,
invariant_authorization,
)
if dataset_name:
await push_to_explorer(
dataset_name,
merged_response,
request_body_json,
invariant_authorization,
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -318,10 +323,10 @@ async def push_to_explorer(
async def handle_non_streaming_response(
response: httpx.Response,
dataset_name: str,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: str,
):
) -> Response:
"""Handles non-streaming OpenAI responses"""
try:
json_response = response.json()
@@ -335,9 +340,10 @@ async def handle_non_streaming_response(
status_code=response.status_code,
detail=json_response.get("error", "Unknown error from OpenAI API"),
)
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
)
if dataset_name:
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
)
return Response(
content=json.dumps(json_response),
@@ -1,28 +1,45 @@
from unittest.mock import patch
import os
import anthropic
from httpx import Client
import datetime
"""Test the Anthropic proxy with Invariant key in the ANTHROPIC_API_KEY."""
import pytest
import datetime
import os
import sys
from unittest.mock import patch
import anthropic
import pytest
from httpx import Client
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
pytest_plugins = ("pytest_asyncio")
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_header(
context, proxy_url, explorer_api_url
pytest_plugins = ("pytest_asyncio",)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_proxy_with_invariant_key_in_anthropic_key(
context, proxy_url, explorer_api_url, push_to_explorer
):
"""Test the Anthropic proxy with Invariant key in the Anthropic key"""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
dataset_name = "claude_header_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": anthropic_api_key + "|invariant-auth: <not needed for test>"}):
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
with patch.dict(
os.environ,
{
"ANTHROPIC_API_KEY": anthropic_api_key
+ "|invariant-auth: <not needed for test>"
},
):
client = anthropic.Anthropic(
http_client=Client(),
base_url = f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
http_client=Client(),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/anthropic",
)
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
@@ -30,33 +47,31 @@ async def test_header(
messages=[
{
"role": "user",
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
}
]
],
)
assert response is not None
response_text = response.content[0].text
assert "zurich" in response_text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
get_trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await get_trace_response.json()
assert trace["messages"] == [
{
"role": "user",
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
},
{
"role": "assistant",
"content": response_text
}
]
trace_id = traces[0]["id"]
get_trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await get_trace_response.json()
assert trace["messages"] == [
{
"role": "user",
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
},
{"role": "assistant", "content": response_text},
]
+184 -150
View File
@@ -1,13 +1,16 @@
"""Test the Anthropic messages API with tool call for the weather agent."""
import base64
import datetime
import os
from typing import Dict
import json
import os
import sys
from pathlib import Path
from typing import Dict, List
import anthropic
import pytest
from httpx import Client
import base64
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -17,16 +20,20 @@ pytest_plugins = ("pytest_asyncio",)
class WeatherAgent:
def __init__(self,proxy_url):
"""Weather agent to get the current weather in a given location."""
def __init__(self, proxy_url, push_to_explorer):
self.dataset_name = "claude_weather_agent_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
self.client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic",
base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/anthropic",
)
self.get_weather_function = {
"name": "get_weather",
@@ -48,7 +55,7 @@ class WeatherAgent:
},
}
def get_response(self, messages: str) -> Dict:
def get_response(self, messages: List[Dict]) -> List[Dict]:
"""
Get the response from the agent for a given user query for weather.
"""
@@ -58,7 +65,7 @@ class WeatherAgent:
tools=[self.get_weather_function],
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=messages
messages=messages,
)
response_list.append(response)
# If there's tool call, Extract the tool call parameters from the response
@@ -81,18 +88,19 @@ class WeatherAgent:
)
else:
return response_list
def get_streaming_response(self, messages: str) -> Dict:
def get_streaming_response(self, messages: List[Dict]) -> List[Dict]:
"""Get streaming response from the agent for a given user query for weather."""
response_list = []
def clean_quotes(text):
# Convert \' to '
text = text.replace("\'", "'")
text = text.replace("'", "'")
# Convert \" to "
text = text.replace('\"', '"')
text = text.replace('"', '"')
text = text.replace("\n", " ")
return text
while True:
json_data = ""
content = []
@@ -108,26 +116,35 @@ class WeatherAgent:
current_block = event.content_block
current_text = ""
elif isinstance(event, anthropic.types.RawContentBlockDeltaEvent):
if hasattr(event.delta, 'text'):
if hasattr(event.delta, "text"):
# Accumulate text for TextBlock
current_text += clean_quotes(event.delta.text)
elif hasattr(event.delta, 'partial_json'):
elif hasattr(event.delta, "partial_json"):
# Accumulate JSON for ToolUseBlock
json_data += clean_quotes(event.delta.partial_json)
current_text += clean_quotes(event.delta.partial_json)
elif isinstance(event, anthropic.types.RawContentBlockStopEvent):
# Block is complete, add it to content
if current_block.type == 'text':
content.append(anthropic.types.TextBlock(citations=None, text=current_text, type="text"))
elif current_block.type == 'tool_use':
if current_block.type == "text":
content.append(
anthropic.types.ToolUseBlock(id=current_block.id,
input=json.loads(current_text),
name=current_block.name,
type="tool_use")
anthropic.types.TextBlock(
citations=None, text=current_text, type="text"
)
)
response_list.append(content)
if isinstance(event, anthropic.types.RawMessageStopEvent) and event.message.stop_reason == "tool_use":
elif current_block.type == "tool_use":
content.append(
anthropic.types.ToolUseBlock(
id=current_block.id,
input=json.loads(current_text),
name=current_block.name,
type="tool_use",
)
)
response_list.append(content)
if (
isinstance(event, anthropic.types.RawMessageStopEvent)
and event.message.stop_reason == "tool_use"
):
tool_call_params = json.loads(json_data)
tool_call_result = self.get_weather(tool_call_params["location"])
messages.append({"role": "assistant", "content": content})
@@ -148,21 +165,25 @@ class WeatherAgent:
def get_weather(self, location: str):
"""Get the current weather in a given location using latitude and longitude."""
response = f'''Weather in {location}:
response = f"""Weather in {location}:
Good morning! Expect overcast skies with intermittent showers throughout the day.
Temperatures will range from a cool 15°C in the early hours to around 19°C by mid-afternoon.
Light winds from the northeast at about 10 km/h will keep conditions mild.
It might be a good idea to carry an umbrella if youre heading out. Stay dry and have a great day!
'''
"""
return response
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_response_with_toolcall(
context, explorer_api_url, proxy_url
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_with_tool_call(
context, explorer_api_url, proxy_url, push_to_explorer
):
"""Test the chat completion without streaming for the weather agent."""
weather_agent = WeatherAgent(proxy_url)
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
query = "Tell me the weather for New York"
@@ -183,38 +204,46 @@ async def test_response_with_toolcall(
assert city in response[1].content[0].text.lower()
responses.append(response)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert (
city
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
"location"
].lower()
)
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_streaming_response_with_toolcall(
context, explorer_api_url, proxy_url
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_streaming_response_with_tool_call(
context, explorer_api_url, proxy_url, push_to_explorer
):
"""Test the chat completion with streaming for the weather agent."""
weather_agent = WeatherAgent(proxy_url)
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
query = "Tell me the weather for New York"
city = "new york"
@@ -226,104 +255,109 @@ async def test_streaming_response_with_toolcall(
assert response[0][1].type == "tool_use"
assert response[0][1].name == "get_weather"
assert city in response[0][1].input["location"].lower()
assert response[1][0].type == "text"
assert city in response[1][0].text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert (
city
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
"location"
].lower()
)
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
assert city in trace_messages[4]["content"].lower()
async def test_response_with_toolcall_with_image(
context, explorer_api_url, proxy_url
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_with_tool_call_with_image(
context, explorer_api_url, proxy_url, push_to_explorer
):
weatherAgent = WeatherAgent(proxy_url)
"""Test the chat completion with image for the weather agent."""
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
image_path1 = Path(__file__).parent.parent / "images" / "new-york.jpeg"
image_path2 = Path(__file__).parent.parent / "images" / "two-cats.png"
image_path = Path(__file__).parent.parent / "images" / "new-york.jpeg"
image1 = open(image_path1, "rb")
image2 = open(image_path2, "rb")
base64_image1 = base64.b64encode(image1.read()).decode("utf-8")
base64_image2 = base64.b64encode(image2.read()).decode("utf-8")
query = "get the weather in the city of these images"
city = "new york"
messages = [
{
"role": "user", "content": [
{
"type": "text",
"text": query,
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": base64_image1,
}
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image2,
}
},
]
}
]
response = weatherAgent.get_response(messages)
assert response is not None
assert response[0].role == "assistant"
assert response[0].stop_reason == "tool_use"
assert response[0].content[0].type == "text"
assert response[0].content[1].type == "tool_use"
assert city in response[0].content[1].input["location"].lower()
with image_path.open("rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
query = "get the weather in the city of this image"
city = "new york"
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": query},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": base64_image,
},
},
],
}
]
response = weather_agent.get_response(messages)
assert response is not None
assert response[0].role == "assistant"
assert response[0].stop_reason == "tool_use"
assert response[0].content[0].type == "text"
assert response[0].content[1].type == "tool_use"
assert city in response[0].content[1].input["location"].lower()
assert response[1].role == "assistant"
assert response[1].stop_reason == "end_turn"
assert response[1].role == "assistant"
assert response[1].stop_reason == "end_turn"
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weatherAgent.dataset_name}/traces"
)
traces = await traces_response.json()
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
)
traces = await traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
trace = traces[-1]
trace_id = trace["id"]
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[1]["role"] == "assistant"
assert city in trace_messages[1]["content"].lower()
assert trace_messages[2]["role"] == "assistant"
assert (
trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
)
assert (
city
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
"location"
].lower()
)
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
@@ -1,106 +1,116 @@
import anthropic
import os
from httpx import Client
"""Tests for the Anthropic API without tool call."""
import datetime
import pytest
import os
import sys
import anthropic
import pytest
from httpx import Client
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from util import * # needed for pytest fixtures
pytest_plugins = ("pytest_asyncio")
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_response_without_toolcall(
context, explorer_api_url,proxy_url
pytest_plugins = ("pytest_asyncio",)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_without_tool_call(
context, explorer_api_url, proxy_url, push_to_explorer
):
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
"""Test the Anthropic proxy without tool calling."""
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
)
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/anthropic",
)
cities = ["zurich", "new york", "london"]
queries = [
"Can you introduce Zurich, Switzerland within 200 words?",
"Tell me the history of New York within 100 words?",
"How's the weather in London next week?"
"How's the weather in London next week?",
]
# Process each query
responses = []
for query in queries:
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=[{"role": "user", "content": query}],
)
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=[{"role": "user", "content": query}],
)
response_text = response.content[0].text
responses.append(response_text)
assert response_text is not None
assert cities[queries.index(query)] in response_text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == len(queries)
for index,trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
trace = await trace_response.json()
assert trace["messages"] == [
{
"role": "user",
"content": queries[index]
},
{
"role": "assistant",
"content": responses[index]
}
]
traces = await traces_response.json()
assert len(traces) == len(queries)
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
async def test_streaming_response_without_toolcall(
context,
explorer_api_url,
proxy_url
):
for index, trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
assert trace["messages"] == [
{"role": "user", "content": queries[index]},
{"role": "assistant", "content": responses[index]},
]
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_streaming_response_without_tool_call(
context, explorer_api_url, proxy_url, push_to_explorer
):
"""Test the Anthropic proxy with streaming responses and without tool calling."""
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
)
http_client=Client(
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/anthropic",
)
cities = ["zurich", "new york", "london"]
queries = [
"Can you introduce Zurich, Switzerland within 200 words?",
"Tell me the history of New York within 100 words?",
"How's the weather in London next week?"
"How's the weather in London next week?",
]
# Process each query
responses = []
for index,query in enumerate(queries):
messages = [
{
"role": "user",
"content": query
}
]
for index, query in enumerate(queries):
messages = [{"role": "user", "content": query}]
response_text = ""
with client.messages.stream(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
@@ -113,26 +123,21 @@ async def test_streaming_response_without_toolcall(
assert response_text is not None
assert cities[queries.index(query)] in response_text.lower()
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == len(queries)
for index,trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
if push_to_explorer:
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
trace = await trace_response.json()
assert trace["messages"] == [
{
"role": "user",
"content": queries[index]
},
{
"role": "assistant",
"content": responses[index]
}
]
traces = await traces_response.json()
assert len(traces) == len(queries)
for index, trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
assert trace["messages"] == [
{"role": "user", "content": queries[index]},
{"role": "assistant", "content": responses[index]},
]
+54 -45
View File
@@ -7,6 +7,7 @@ import uuid
import pytest
from httpx import Client
# add tests folder (parent) to sys.path
from openai import OpenAI
@@ -18,8 +19,9 @@ pytest_plugins = ("pytest_asyncio",)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_chat_completion_with_tool_call_without_streaming(
context, explorer_api_url, proxy_url
context, explorer_api_url, proxy_url, push_to_explorer
):
"""
Test the chat completions proxy calls with tool calling and response processing
@@ -33,7 +35,9 @@ async def test_chat_completion_with_tool_call_without_streaming(
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/openai",
)
chat_response = client.chat.completions.create(
@@ -96,36 +100,38 @@ async def test_chat_completion_with_tool_call_without_streaming(
)
assert "15°C" in chat_response_final.choices[0].message.content
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
if push_to_explorer:
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Verify the trace messages
expected_messages = history + [
{
"role": "assistant",
"content": chat_response_final.choices[0].message.content,
}
]
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
# Verify the trace messages
expected_messages = history + [
{
"role": "assistant",
"content": chat_response_final.choices[0].message.content,
}
]
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_chat_completion_with_tool_call_with_streaming(
context, explorer_api_url, proxy_url
context, explorer_api_url, proxy_url, push_to_explorer
):
"""
Test the chat completions proxy calls with tool calling and response processing
@@ -139,7 +145,9 @@ async def test_chat_completion_with_tool_call_with_streaming(
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/openai",
)
chat_response = client.chat.completions.create(
@@ -209,23 +217,24 @@ async def test_chat_completion_with_tool_call_with_streaming(
if chunk.choices and chunk.choices[0].delta.content:
final_response["content"] += chunk.choices[0].delta.content
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
if push_to_explorer:
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Verify the trace messages
expected_messages = history + [final_response]
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
# Verify the trace messages
expected_messages = history + [final_response]
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
)
assert trace["messages"] == expected_messages
+76 -57
View File
@@ -21,8 +21,13 @@ pytest_plugins = ("pytest_asyncio",)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("do_stream", [True, False])
async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_chat_completion(
context, explorer_api_url, proxy_url, do_stream, push_to_explorer
):
"""Test the chat completions proxy calls without tool calling."""
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
@@ -32,7 +37,9 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/openai",
)
chat_response = client.chat.completions.create(
@@ -53,35 +60,39 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
assert "PARIS" in full_response.upper()
expected_assistant_message = full_response
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
if push_to_explorer:
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Verify the trace messages
assert trace["messages"] == [
{
"role": "user",
"content": "What is the capital of France?",
},
{
"role": "assistant",
"content": expected_assistant_message,
},
]
# Verify the trace messages
assert trace["messages"] == [
{
"role": "user",
"content": "What is the capital of France?",
},
{
"role": "assistant",
"content": expected_assistant_message,
},
]
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
@pytest.mark.parametrize("push_to_explorer", [True, False])
async def test_chat_completion_with_image(
context, explorer_api_url, proxy_url, push_to_explorer
):
"""Test the chat completions proxy works with image."""
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
@@ -91,7 +102,9 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
if push_to_explorer
else f"{proxy_url}/api/v1/proxy/openai",
)
image_path = Path(__file__).parent.parent / "images" / "two-cats.png"
with image_path.open("rb") as image_file:
@@ -121,37 +134,43 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
assert "TWO" in chat_response.choices[0].message.content.upper()
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
if push_to_explorer:
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
)
traces = await traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
)
trace = await trace_response.json()
# Verify the trace messages
assert trace["messages"] == [
{
"role": "user",
"content": [
{"type": "text", "text": "How many cats are there in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64," + base64_image},
},
],
},
{
"role": "assistant",
"content": chat_response.choices[0].message.content,
},
]
# Verify the trace messages
assert trace["messages"] == [
{
"role": "user",
"content": [
{
"type": "text",
"text": "How many cats are there in this image?",
},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64," + base64_image
},
},
],
},
{
"role": "assistant",
"content": chat_response.choices[0].message.content,
},
]
@pytest.mark.skip(reason="Skipping this test: OpenAI error scenario")