mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-06 13:13:55 +02:00
Add API endpoints so that the Proxy can be used without pushing to Explorer.
This commit is contained in:
+34
-26
@@ -1,10 +1,10 @@
|
||||
"""Proxy service to forward requests to the Anthropic APIs"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from utils.explorer import push_trace
|
||||
@@ -36,9 +36,13 @@ def validate_headers(x_api_key: str = Header(None)):
|
||||
"/{dataset_name}/anthropic/v1/messages",
|
||||
dependencies=[Depends(validate_headers)],
|
||||
)
|
||||
@proxy.post(
|
||||
"/anthropic/v1/messages",
|
||||
dependencies=[Depends(validate_headers)],
|
||||
)
|
||||
async def anthropic_v1_messages_proxy(
|
||||
dataset_name: str,
|
||||
request: Request,
|
||||
dataset_name: str = None,
|
||||
):
|
||||
"""Proxy calls to the Anthropic APIs"""
|
||||
headers = {
|
||||
@@ -77,17 +81,10 @@ async def anthropic_v1_messages_proxy(
|
||||
client, anthropic_request, dataset_name, invariant_authorization
|
||||
)
|
||||
else:
|
||||
try:
|
||||
response = await client.send(anthropic_request)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to fetch response from Anthropic: {response.text}, got error{e}",
|
||||
)
|
||||
await handle_non_streaming_response(
|
||||
response = await client.send(anthropic_request)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
async def push_to_explorer(
|
||||
@@ -116,7 +113,7 @@ async def handle_non_streaming_response(
|
||||
dataset_name: str,
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
):
|
||||
) -> Response:
|
||||
"""Handles non-streaming Anthropic responses"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
@@ -131,20 +128,28 @@ async def handle_non_streaming_response(
|
||||
detail=json_response.get("error", "Unknown error from Anthropic"),
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
json_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
if dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
json_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
status_code=response.status_code,
|
||||
media_type="application/json",
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
client: httpx.AsyncClient,
|
||||
anthropic_request: httpx.Request,
|
||||
dataset_name: str,
|
||||
dataset_name: Optional[str],
|
||||
invariant_authorization: str,
|
||||
) -> StreamingResponse:
|
||||
"""Handles streaming Anthropic responses"""
|
||||
formatted_invariant_response = []
|
||||
|
||||
response = await client.send(anthropic_request, stream=True)
|
||||
@@ -165,13 +170,13 @@ async def handle_streaming_response(
|
||||
yield chunk
|
||||
|
||||
process_chunk_text(chunk_decode, formatted_invariant_response)
|
||||
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
if dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
formatted_invariant_response[-1],
|
||||
json.loads(anthropic_request.content),
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
generator = event_generator()
|
||||
|
||||
@@ -193,6 +198,7 @@ def process_chunk_text(chunk_decode, formatted_invariant_response):
|
||||
|
||||
|
||||
def update_formatted_invariant_response(text_json, formatted_invariant_response):
|
||||
"""Update the formatted_invariant_response based on the text_json"""
|
||||
if text_json.get("type") == MESSAGE_START:
|
||||
message = text_json.get("message")
|
||||
formatted_invariant_response.append(
|
||||
@@ -252,6 +258,7 @@ def anthropic_to_invariant_messages(
|
||||
|
||||
|
||||
def handle_user_message(message, keep_empty_tool_response):
|
||||
"""Handle the user message from the Anthropic API"""
|
||||
output = []
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
@@ -298,6 +305,7 @@ def handle_user_message(message, keep_empty_tool_response):
|
||||
|
||||
|
||||
def handle_assistant_message(message):
|
||||
"""Handle the assistant message from the Anthropic API"""
|
||||
output = []
|
||||
if isinstance(message["content"], list):
|
||||
for sub_message in message["content"]:
|
||||
|
||||
+20
-14
@@ -1,7 +1,7 @@
|
||||
"""Proxy service to forward requests to the OpenAI APIs"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
@@ -26,9 +26,13 @@ def validate_headers(authorization: str = Header(None)):
|
||||
"/{dataset_name}/openai/chat/completions",
|
||||
dependencies=[Depends(validate_headers)],
|
||||
)
|
||||
@proxy.post(
|
||||
"/openai/chat/completions",
|
||||
dependencies=[Depends(validate_headers)],
|
||||
)
|
||||
async def openai_chat_completions_proxy(
|
||||
request: Request,
|
||||
dataset_name: str,
|
||||
dataset_name: str = None,
|
||||
) -> Response:
|
||||
"""Proxy calls to the OpenAI APIs"""
|
||||
|
||||
@@ -92,7 +96,7 @@ async def openai_chat_completions_proxy(
|
||||
async def stream_response(
|
||||
client: httpx.AsyncClient,
|
||||
open_ai_request: httpx.Request,
|
||||
dataset_name: str,
|
||||
dataset_name: Optional[str],
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
) -> Response:
|
||||
@@ -150,12 +154,13 @@ async def stream_response(
|
||||
)
|
||||
|
||||
# Send full merged response to the explorer
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
merged_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
if dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name,
|
||||
merged_response,
|
||||
request_body_json,
|
||||
invariant_authorization,
|
||||
)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -318,10 +323,10 @@ async def push_to_explorer(
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
response: httpx.Response,
|
||||
dataset_name: str,
|
||||
dataset_name: Optional[str],
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: str,
|
||||
):
|
||||
) -> Response:
|
||||
"""Handles non-streaming OpenAI responses"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
@@ -335,9 +340,10 @@ async def handle_non_streaming_response(
|
||||
status_code=response.status_code,
|
||||
detail=json_response.get("error", "Unknown error from OpenAI API"),
|
||||
)
|
||||
await push_to_explorer(
|
||||
dataset_name, json_response, request_body_json, invariant_authorization
|
||||
)
|
||||
if dataset_name:
|
||||
await push_to_explorer(
|
||||
dataset_name, json_response, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
|
||||
@@ -1,28 +1,45 @@
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
import anthropic
|
||||
from httpx import Client
|
||||
import datetime
|
||||
"""Test the Anthropic proxy with Invariant key in the ANTHROPIC_API_KEY."""
|
||||
|
||||
import pytest
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from httpx import Client
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from util import * # needed for pytest fixtures
|
||||
|
||||
pytest_plugins = ("pytest_asyncio")
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_header(
|
||||
context, proxy_url, explorer_api_url
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_proxy_with_invariant_key_in_anthropic_key(
|
||||
context, proxy_url, explorer_api_url, push_to_explorer
|
||||
):
|
||||
"""Test the Anthropic proxy with Invariant key in the Anthropic key"""
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
dataset_name = "claude_header_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": anthropic_api_key + "|invariant-auth: <not needed for test>"}):
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"ANTHROPIC_API_KEY": anthropic_api_key
|
||||
+ "|invariant-auth: <not needed for test>"
|
||||
},
|
||||
):
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(),
|
||||
base_url = f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
|
||||
http_client=Client(),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/anthropic",
|
||||
)
|
||||
response = client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
@@ -30,33 +47,31 @@ async def test_header(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
|
||||
}
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
response_text = response.content[0].text
|
||||
assert "zurich" in response_text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
}
|
||||
]
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
|
||||
},
|
||||
{"role": "assistant", "content": response_text},
|
||||
]
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""Test the Anthropic messages API with tool call for the weather agent."""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from httpx import Client
|
||||
import base64
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
@@ -17,16 +20,20 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
class WeatherAgent:
|
||||
def __init__(self,proxy_url):
|
||||
"""Weather agent to get the current weather in a given location."""
|
||||
|
||||
def __init__(self, proxy_url, push_to_explorer):
|
||||
self.dataset_name = "claude_weather_agent_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/anthropic",
|
||||
)
|
||||
self.get_weather_function = {
|
||||
"name": "get_weather",
|
||||
@@ -48,7 +55,7 @@ class WeatherAgent:
|
||||
},
|
||||
}
|
||||
|
||||
def get_response(self, messages: str) -> Dict:
|
||||
def get_response(self, messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Get the response from the agent for a given user query for weather.
|
||||
"""
|
||||
@@ -58,7 +65,7 @@ class WeatherAgent:
|
||||
tools=[self.get_weather_function],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=messages
|
||||
messages=messages,
|
||||
)
|
||||
response_list.append(response)
|
||||
# If there's tool call, Extract the tool call parameters from the response
|
||||
@@ -81,18 +88,19 @@ class WeatherAgent:
|
||||
)
|
||||
else:
|
||||
return response_list
|
||||
|
||||
def get_streaming_response(self, messages: str) -> Dict:
|
||||
|
||||
def get_streaming_response(self, messages: List[Dict]) -> List[Dict]:
|
||||
"""Get streaming response from the agent for a given user query for weather."""
|
||||
response_list = []
|
||||
|
||||
def clean_quotes(text):
|
||||
# Convert \' to '
|
||||
text = text.replace("\'", "'")
|
||||
text = text.replace("'", "'")
|
||||
# Convert \" to "
|
||||
text = text.replace('\"', '"')
|
||||
text = text.replace('"', '"')
|
||||
text = text.replace("\n", " ")
|
||||
return text
|
||||
|
||||
|
||||
while True:
|
||||
json_data = ""
|
||||
content = []
|
||||
@@ -108,26 +116,35 @@ class WeatherAgent:
|
||||
current_block = event.content_block
|
||||
current_text = ""
|
||||
elif isinstance(event, anthropic.types.RawContentBlockDeltaEvent):
|
||||
if hasattr(event.delta, 'text'):
|
||||
if hasattr(event.delta, "text"):
|
||||
# Accumulate text for TextBlock
|
||||
current_text += clean_quotes(event.delta.text)
|
||||
elif hasattr(event.delta, 'partial_json'):
|
||||
elif hasattr(event.delta, "partial_json"):
|
||||
# Accumulate JSON for ToolUseBlock
|
||||
json_data += clean_quotes(event.delta.partial_json)
|
||||
current_text += clean_quotes(event.delta.partial_json)
|
||||
elif isinstance(event, anthropic.types.RawContentBlockStopEvent):
|
||||
# Block is complete, add it to content
|
||||
if current_block.type == 'text':
|
||||
content.append(anthropic.types.TextBlock(citations=None, text=current_text, type="text"))
|
||||
elif current_block.type == 'tool_use':
|
||||
if current_block.type == "text":
|
||||
content.append(
|
||||
anthropic.types.ToolUseBlock(id=current_block.id,
|
||||
input=json.loads(current_text),
|
||||
name=current_block.name,
|
||||
type="tool_use")
|
||||
anthropic.types.TextBlock(
|
||||
citations=None, text=current_text, type="text"
|
||||
)
|
||||
)
|
||||
response_list.append(content)
|
||||
if isinstance(event, anthropic.types.RawMessageStopEvent) and event.message.stop_reason == "tool_use":
|
||||
elif current_block.type == "tool_use":
|
||||
content.append(
|
||||
anthropic.types.ToolUseBlock(
|
||||
id=current_block.id,
|
||||
input=json.loads(current_text),
|
||||
name=current_block.name,
|
||||
type="tool_use",
|
||||
)
|
||||
)
|
||||
response_list.append(content)
|
||||
if (
|
||||
isinstance(event, anthropic.types.RawMessageStopEvent)
|
||||
and event.message.stop_reason == "tool_use"
|
||||
):
|
||||
tool_call_params = json.loads(json_data)
|
||||
tool_call_result = self.get_weather(tool_call_params["location"])
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
@@ -148,21 +165,25 @@ class WeatherAgent:
|
||||
|
||||
def get_weather(self, location: str):
|
||||
"""Get the current weather in a given location using latitude and longitude."""
|
||||
response = f'''Weather in {location}:
|
||||
response = f"""Weather in {location}:
|
||||
Good morning! Expect overcast skies with intermittent showers throughout the day.
|
||||
Temperatures will range from a cool 15°C in the early hours to around 19°C by mid-afternoon.
|
||||
Light winds from the northeast at about 10 km/h will keep conditions mild.
|
||||
It might be a good idea to carry an umbrella if you’re heading out. Stay dry and have a great day!
|
||||
'''
|
||||
"""
|
||||
return response
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_response_with_toolcall(
|
||||
context, explorer_api_url, proxy_url
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_with_tool_call(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completion without streaming for the weather agent."""
|
||||
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
|
||||
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
|
||||
|
||||
query = "Tell me the weather for New York"
|
||||
|
||||
@@ -183,38 +204,46 @@ async def test_response_with_toolcall(
|
||||
assert city in response[1].content[0].text.lower()
|
||||
responses.append(response)
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert (
|
||||
city
|
||||
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
|
||||
"location"
|
||||
].lower()
|
||||
)
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_streaming_response_with_toolcall(
|
||||
context, explorer_api_url, proxy_url
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_streaming_response_with_tool_call(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completion with streaming for the weather agent."""
|
||||
weather_agent = WeatherAgent(proxy_url)
|
||||
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
|
||||
|
||||
query = "Tell me the weather for New York"
|
||||
city = "new york"
|
||||
@@ -226,104 +255,109 @@ async def test_streaming_response_with_toolcall(
|
||||
assert response[0][1].type == "tool_use"
|
||||
assert response[0][1].name == "get_weather"
|
||||
assert city in response[0][1].input["location"].lower()
|
||||
|
||||
|
||||
assert response[1][0].type == "text"
|
||||
assert city in response[1][0].text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert (
|
||||
city
|
||||
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
|
||||
"location"
|
||||
].lower()
|
||||
)
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
assert city in trace_messages[4]["content"].lower()
|
||||
|
||||
|
||||
async def test_response_with_toolcall_with_image(
|
||||
context, explorer_api_url, proxy_url
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_with_tool_call_with_image(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
weatherAgent = WeatherAgent(proxy_url)
|
||||
"""Test the chat completion with image for the weather agent."""
|
||||
weather_agent = WeatherAgent(proxy_url, push_to_explorer)
|
||||
|
||||
image_path1 = Path(__file__).parent.parent / "images" / "new-york.jpeg"
|
||||
image_path2 = Path(__file__).parent.parent / "images" / "two-cats.png"
|
||||
image_path = Path(__file__).parent.parent / "images" / "new-york.jpeg"
|
||||
|
||||
image1 = open(image_path1, "rb")
|
||||
image2 = open(image_path2, "rb")
|
||||
base64_image1 = base64.b64encode(image1.read()).decode("utf-8")
|
||||
base64_image2 = base64.b64encode(image2.read()).decode("utf-8")
|
||||
query = "get the weather in the city of these images"
|
||||
city = "new york"
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
|
||||
{
|
||||
"type": "text",
|
||||
"text": query,
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": base64_image1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_image2,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
response = weatherAgent.get_response(messages)
|
||||
assert response is not None
|
||||
assert response[0].role == "assistant"
|
||||
assert response[0].stop_reason == "tool_use"
|
||||
assert response[0].content[0].type == "text"
|
||||
assert response[0].content[1].type == "tool_use"
|
||||
assert city in response[0].content[1].input["location"].lower()
|
||||
with image_path.open("rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
query = "get the weather in the city of this image"
|
||||
city = "new york"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": query},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": base64_image,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
response = weather_agent.get_response(messages)
|
||||
assert response is not None
|
||||
assert response[0].role == "assistant"
|
||||
assert response[0].stop_reason == "tool_use"
|
||||
assert response[0].content[0].type == "text"
|
||||
assert response[0].content[1].type == "tool_use"
|
||||
assert city in response[0].content[1].input["location"].lower()
|
||||
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weatherAgent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower()
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
assert city in trace_messages[1]["content"].lower()
|
||||
assert trace_messages[2]["role"] == "assistant"
|
||||
assert (
|
||||
trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
)
|
||||
assert (
|
||||
city
|
||||
in trace_messages[2]["tool_calls"][0]["function"]["arguments"][
|
||||
"location"
|
||||
].lower()
|
||||
)
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
|
||||
@@ -1,106 +1,116 @@
|
||||
import anthropic
|
||||
import os
|
||||
from httpx import Client
|
||||
"""Tests for the Anthropic API without tool call."""
|
||||
|
||||
import datetime
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
from httpx import Client
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from util import * # needed for pytest fixtures
|
||||
|
||||
pytest_plugins = ("pytest_asyncio")
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_response_without_toolcall(
|
||||
context, explorer_api_url,proxy_url
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_without_tool_call(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
|
||||
"""Test the Anthropic proxy without tool calling."""
|
||||
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/anthropic",
|
||||
)
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
"Can you introduce Zurich, Switzerland within 200 words?",
|
||||
"Tell me the history of New York within 100 words?",
|
||||
"How's the weather in London next week?"
|
||||
"How's the weather in London next week?",
|
||||
]
|
||||
|
||||
# Process each query
|
||||
responses = []
|
||||
for query in queries:
|
||||
response = client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": query}],
|
||||
)
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": query}],
|
||||
)
|
||||
response_text = response.content[0].text
|
||||
responses.append(response_text)
|
||||
assert response_text is not None
|
||||
assert cities[queries.index(query)] in response_text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index,trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": queries[index]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": responses[index]
|
||||
}
|
||||
]
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
async def test_streaming_response_without_toolcall(
|
||||
context,
|
||||
explorer_api_url,
|
||||
proxy_url
|
||||
):
|
||||
for index, trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{"role": "user", "content": queries[index]},
|
||||
{"role": "assistant", "content": responses[index]},
|
||||
]
|
||||
|
||||
dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY","None")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_streaming_response_without_tool_call(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""Test the Anthropic proxy without tool calling."""
|
||||
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
|
||||
http_client=Client(
|
||||
headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"},
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/anthropic",
|
||||
)
|
||||
|
||||
cities = ["zurich", "new york", "london"]
|
||||
queries = [
|
||||
"Can you introduce Zurich, Switzerland within 200 words?",
|
||||
"Tell me the history of New York within 100 words?",
|
||||
"How's the weather in London next week?"
|
||||
"How's the weather in London next week?",
|
||||
]
|
||||
# Process each query
|
||||
responses = []
|
||||
for index,query in enumerate(queries):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": query
|
||||
}
|
||||
]
|
||||
for index, query in enumerate(queries):
|
||||
messages = [{"role": "user", "content": query}]
|
||||
response_text = ""
|
||||
|
||||
|
||||
with client.messages.stream(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
@@ -113,26 +123,21 @@ async def test_streaming_response_without_toolcall(
|
||||
assert response_text is not None
|
||||
assert cities[queries.index(query)] in response_text.lower()
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index,trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": queries[index]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": responses[index]
|
||||
}
|
||||
]
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index, trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{"role": "user", "content": queries[index]},
|
||||
{"role": "assistant", "content": responses[index]},
|
||||
]
|
||||
|
||||
@@ -7,6 +7,7 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from httpx import Client
|
||||
|
||||
# add tests folder (parent) to sys.path
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -18,8 +19,9 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_chat_completion_with_tool_call_without_streaming(
|
||||
context, explorer_api_url, proxy_url
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""
|
||||
Test the chat completions proxy calls with tool calling and response processing
|
||||
@@ -33,7 +35,9 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/openai",
|
||||
)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
@@ -96,36 +100,38 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
)
|
||||
assert "15°C" in chat_response_final.choices[0].message.content
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
if push_to_explorer:
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
expected_messages = history + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": chat_response_final.choices[0].message.content,
|
||||
}
|
||||
]
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
# Verify the trace messages
|
||||
expected_messages = history + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": chat_response_final.choices[0].message.content,
|
||||
}
|
||||
]
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_chat_completion_with_tool_call_with_streaming(
|
||||
context, explorer_api_url, proxy_url
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""
|
||||
Test the chat completions proxy calls with tool calling and response processing
|
||||
@@ -139,7 +145,9 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/openai",
|
||||
)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
@@ -209,23 +217,24 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
final_response["content"] += chunk.choices[0].delta.content
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
if push_to_explorer:
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
expected_messages = history + [final_response]
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
# Verify the trace messages
|
||||
expected_messages = history + [final_response]
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
|
||||
@@ -21,8 +21,13 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, push_to_explorer",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_chat_completion(
|
||||
context, explorer_api_url, proxy_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test the chat completions proxy calls without tool calling."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
|
||||
@@ -32,7 +37,9 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/openai",
|
||||
)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
@@ -53,35 +60,39 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream):
|
||||
assert "PARIS" in full_response.upper()
|
||||
expected_assistant_message = full_response
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
if push_to_explorer:
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": expected_assistant_message,
|
||||
},
|
||||
]
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": expected_assistant_message,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
|
||||
@pytest.mark.parametrize("push_to_explorer", [True, False])
|
||||
async def test_chat_completion_with_image(
|
||||
context, explorer_api_url, proxy_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completions proxy works with image."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
|
||||
@@ -91,7 +102,9 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/openai",
|
||||
)
|
||||
image_path = Path(__file__).parent.parent / "images" / "two-cats.png"
|
||||
with image_path.open("rb") as image_file:
|
||||
@@ -121,37 +134,43 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url):
|
||||
|
||||
assert "TWO" in chat_response.choices[0].message.content.upper()
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
if push_to_explorer:
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "How many cats are there in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64," + base64_image},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": chat_response.choices[0].message.content,
|
||||
},
|
||||
]
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "How many cats are there in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64," + base64_image
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": chat_response.choices[0].message.content,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test: OpenAI error scenario")
|
||||
|
||||
Reference in New Issue
Block a user