From 6afbcd3ea0efff7a558a750deff4bbd545063190 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 25 Feb 2025 21:21:24 +0100 Subject: [PATCH] Add API endpoints so that the Proxy can be used without pushing to Explorer. --- proxy/routes/anthropic.py | 60 ++-- proxy/routes/open_ai.py | 34 +- ...est_anthropic_header_with_invariant_key.py | 89 +++-- .../test_anthropic_with_tool_call.py | 334 ++++++++++-------- .../test_anthropic_without_tool_call.py | 181 +++++----- tests/open_ai/test_chat_with_tool_call.py | 99 +++--- tests/open_ai/test_chat_without_tool_calls.py | 133 ++++--- 7 files changed, 513 insertions(+), 417 deletions(-) diff --git a/proxy/routes/anthropic.py b/proxy/routes/anthropic.py index 23abc60..98d06db 100644 --- a/proxy/routes/anthropic.py +++ b/proxy/routes/anthropic.py @@ -1,10 +1,10 @@ """Proxy service to forward requests to the Anthropic APIs""" import json -from typing import Any +from typing import Any, Optional import httpx -from fastapi import APIRouter, Depends, Header, HTTPException, Request +from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS from utils.explorer import push_trace @@ -36,9 +36,13 @@ def validate_headers(x_api_key: str = Header(None)): "/{dataset_name}/anthropic/v1/messages", dependencies=[Depends(validate_headers)], ) +@proxy.post( + "/anthropic/v1/messages", + dependencies=[Depends(validate_headers)], +) async def anthropic_v1_messages_proxy( - dataset_name: str, request: Request, + dataset_name: str = None, ): """Proxy calls to the Anthropic APIs""" headers = { @@ -77,17 +81,10 @@ async def anthropic_v1_messages_proxy( client, anthropic_request, dataset_name, invariant_authorization ) else: - try: - response = await client.send(anthropic_request) - except httpx.HTTPStatusError as e: - raise HTTPException( - status_code=response.status_code, - detail=f"Failed to fetch response from Anthropic: {response.text}, got error{e}", - ) - await handle_non_streaming_response( + response = await client.send(anthropic_request) + return await handle_non_streaming_response( response, dataset_name, request_body_json, invariant_authorization ) - return response.json() async def push_to_explorer( @@ -116,7 +113,7 @@ async def handle_non_streaming_response( dataset_name: str, request_body_json: dict[str, Any], invariant_authorization: str, -): +) -> Response: """Handles non-streaming Anthropic responses""" try: json_response = response.json() @@ -131,20 +128,28 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from Anthropic"), ) # Only push the trace to explorer if the last message is an end turn message - await push_to_explorer( - dataset_name, - json_response, - request_body_json, - invariant_authorization, + if dataset_name: + await push_to_explorer( + dataset_name, + json_response, + request_body_json, + invariant_authorization, + ) + return Response( + content=json.dumps(json_response), + status_code=response.status_code, + media_type="application/json", + headers=dict(response.headers), ) async def handle_streaming_response( client: httpx.AsyncClient, anthropic_request: httpx.Request, - dataset_name: str, + dataset_name: Optional[str], invariant_authorization: str, ) -> StreamingResponse: + """Handles streaming Anthropic responses""" formatted_invariant_response = [] response = await client.send(anthropic_request, stream=True) @@ -165,13 +170,13 @@ async def handle_streaming_response( yield chunk process_chunk_text(chunk_decode, formatted_invariant_response) - - await push_to_explorer( - dataset_name, - formatted_invariant_response[-1], - json.loads(anthropic_request.content), - invariant_authorization, - ) + if dataset_name: + await push_to_explorer( + dataset_name, + formatted_invariant_response[-1], + json.loads(anthropic_request.content), + invariant_authorization, + ) generator = event_generator() @@ -193,6 +198,7 @@ def process_chunk_text(chunk_decode, formatted_invariant_response): def update_formatted_invariant_response(text_json, formatted_invariant_response): + """Update the formatted_invariant_response based on the text_json""" if text_json.get("type") == MESSAGE_START: message = text_json.get("message") formatted_invariant_response.append( @@ -252,6 +258,7 @@ def anthropic_to_invariant_messages( def handle_user_message(message, keep_empty_tool_response): + """Handle the user message from the Anthropic API""" output = [] content = message["content"] if isinstance(content, list): @@ -298,6 +305,7 @@ def handle_user_message(message, keep_empty_tool_response): def handle_assistant_message(message): + """Handle the assistant message from the Anthropic API""" output = [] if isinstance(message["content"], list): for sub_message in message["content"]: diff --git a/proxy/routes/open_ai.py b/proxy/routes/open_ai.py index ad8578e..11569a1 100644 --- a/proxy/routes/open_ai.py +++ b/proxy/routes/open_ai.py @@ -1,7 +1,7 @@ """Proxy service to forward requests to the OpenAI APIs""" import json -from typing import Any +from typing import Any, Optional import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response @@ -26,9 +26,13 @@ def validate_headers(authorization: str = Header(None)): "/{dataset_name}/openai/chat/completions", dependencies=[Depends(validate_headers)], ) +@proxy.post( + "/openai/chat/completions", + dependencies=[Depends(validate_headers)], +) async def openai_chat_completions_proxy( request: Request, - dataset_name: str, + dataset_name: str = None, ) -> Response: """Proxy calls to the OpenAI APIs""" @@ -92,7 +96,7 @@ async def openai_chat_completions_proxy( async def stream_response( client: httpx.AsyncClient, open_ai_request: httpx.Request, - dataset_name: str, + dataset_name: Optional[str], request_body_json: dict[str, Any], invariant_authorization: str, ) -> Response: @@ -150,12 +154,13 @@ async def stream_response( ) # Send full merged response to the explorer - await push_to_explorer( - dataset_name, - merged_response, - request_body_json, - invariant_authorization, - ) + if dataset_name: + await push_to_explorer( + dataset_name, + merged_response, + request_body_json, + invariant_authorization, + ) return StreamingResponse(event_generator(), media_type="text/event-stream") @@ -318,10 +323,10 @@ async def push_to_explorer( async def handle_non_streaming_response( response: httpx.Response, - dataset_name: str, + dataset_name: Optional[str], request_body_json: dict[str, Any], invariant_authorization: str, -): +) -> Response: """Handles non-streaming OpenAI responses""" try: json_response = response.json() @@ -335,9 +340,10 @@ async def handle_non_streaming_response( status_code=response.status_code, detail=json_response.get("error", "Unknown error from OpenAI API"), ) - await push_to_explorer( - dataset_name, json_response, request_body_json, invariant_authorization - ) + if dataset_name: + await push_to_explorer( + dataset_name, json_response, request_body_json, invariant_authorization + ) return Response( content=json.dumps(json_response), diff --git a/tests/anthropic/test_anthropic_header_with_invariant_key.py b/tests/anthropic/test_anthropic_header_with_invariant_key.py index 88f3728..bbe5bf2 100644 --- a/tests/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/anthropic/test_anthropic_header_with_invariant_key.py @@ -1,28 +1,45 @@ -from unittest.mock import patch -import os -import anthropic -from httpx import Client -import datetime +"""Test the Anthropic proxy with Invariant key in the ANTHROPIC_API_KEY.""" -import pytest +import datetime +import os import sys +from unittest.mock import patch + +import anthropic +import pytest +from httpx import Client + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from util import * # needed for pytest fixtures -pytest_plugins = ("pytest_asyncio") -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") -async def test_header( - context, proxy_url, explorer_api_url +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_proxy_with_invariant_key_in_anthropic_key( + context, proxy_url, explorer_api_url, push_to_explorer ): + """Test the Anthropic proxy with Invariant key in the Anthropic key""" anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") dataset_name = "claude_header_test" + str( - datetime.datetime.now().strftime("%Y%m%d%H%M%S") - ) - with patch.dict(os.environ, {"ANTHROPIC_API_KEY": anthropic_api_key + "|invariant-auth: "}): + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + ) + with patch.dict( + os.environ, + { + "ANTHROPIC_API_KEY": anthropic_api_key + + "|invariant-auth: " + }, + ): client = anthropic.Anthropic( - http_client=Client(), - base_url = f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic", + http_client=Client(), + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/anthropic", ) response = client.messages.create( model="claude-3-5-sonnet-20241022", @@ -30,33 +47,31 @@ async def test_header( messages=[ { "role": "user", - "content": "Give me an introduction to Zurich, Switzerland within 200 words." + "content": "Give me an introduction to Zurich, Switzerland within 200 words.", } - ] + ], ) assert response is not None response_text = response.content[0].text assert "zurich" in response_text.lower() - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 - trace_id = traces[0]["id"] - get_trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await get_trace_response.json() - assert trace["messages"] == [ - { - "role": "user", - "content": "Give me an introduction to Zurich, Switzerland within 200 words." - }, - { - "role": "assistant", - "content": response_text - } - ] \ No newline at end of file + trace_id = traces[0]["id"] + get_trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await get_trace_response.json() + assert trace["messages"] == [ + { + "role": "user", + "content": "Give me an introduction to Zurich, Switzerland within 200 words.", + }, + {"role": "assistant", "content": response_text}, + ] diff --git a/tests/anthropic/test_anthropic_with_tool_call.py b/tests/anthropic/test_anthropic_with_tool_call.py index 87b0881..84ba9b8 100644 --- a/tests/anthropic/test_anthropic_with_tool_call.py +++ b/tests/anthropic/test_anthropic_with_tool_call.py @@ -1,13 +1,16 @@ +"""Test the Anthropic messages API with tool call for the weather agent.""" + +import base64 import datetime -import os -from typing import Dict import json +import os +import sys +from pathlib import Path +from typing import Dict, List + import anthropic import pytest from httpx import Client -import base64 -import sys -from pathlib import Path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -17,16 +20,20 @@ pytest_plugins = ("pytest_asyncio",) class WeatherAgent: - def __init__(self,proxy_url): + """Weather agent to get the current weather in a given location.""" + + def __init__(self, proxy_url, push_to_explorer): self.dataset_name = "claude_weather_agent_test" + str( datetime.datetime.now().strftime("%Y%m%d%H%M%S") ) - invariant_api_key = os.environ.get("INVARIANT_API_KEY","None") + invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") self.client = anthropic.Anthropic( http_client=Client( headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, ), - base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic", + base_url=f"{proxy_url}/api/v1/proxy/{self.dataset_name}/anthropic" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/anthropic", ) self.get_weather_function = { "name": "get_weather", @@ -48,7 +55,7 @@ class WeatherAgent: }, } - def get_response(self, messages: str) -> Dict: + def get_response(self, messages: List[Dict]) -> List[Dict]: """ Get the response from the agent for a given user query for weather. """ @@ -58,7 +65,7 @@ class WeatherAgent: tools=[self.get_weather_function], model="claude-3-5-sonnet-20241022", max_tokens=1024, - messages=messages + messages=messages, ) response_list.append(response) # If there's tool call, Extract the tool call parameters from the response @@ -81,18 +88,19 @@ class WeatherAgent: ) else: return response_list - - def get_streaming_response(self, messages: str) -> Dict: + + def get_streaming_response(self, messages: List[Dict]) -> List[Dict]: + """Get streaming response from the agent for a given user query for weather.""" response_list = [] def clean_quotes(text): # Convert \' to ' - text = text.replace("\'", "'") + text = text.replace("'", "'") # Convert \" to " - text = text.replace('\"', '"') + text = text.replace('"', '"') text = text.replace("\n", " ") return text - + while True: json_data = "" content = [] @@ -108,26 +116,35 @@ class WeatherAgent: current_block = event.content_block current_text = "" elif isinstance(event, anthropic.types.RawContentBlockDeltaEvent): - if hasattr(event.delta, 'text'): + if hasattr(event.delta, "text"): # Accumulate text for TextBlock current_text += clean_quotes(event.delta.text) - elif hasattr(event.delta, 'partial_json'): + elif hasattr(event.delta, "partial_json"): # Accumulate JSON for ToolUseBlock json_data += clean_quotes(event.delta.partial_json) current_text += clean_quotes(event.delta.partial_json) elif isinstance(event, anthropic.types.RawContentBlockStopEvent): # Block is complete, add it to content - if current_block.type == 'text': - content.append(anthropic.types.TextBlock(citations=None, text=current_text, type="text")) - elif current_block.type == 'tool_use': + if current_block.type == "text": content.append( - anthropic.types.ToolUseBlock(id=current_block.id, - input=json.loads(current_text), - name=current_block.name, - type="tool_use") + anthropic.types.TextBlock( + citations=None, text=current_text, type="text" + ) ) - response_list.append(content) - if isinstance(event, anthropic.types.RawMessageStopEvent) and event.message.stop_reason == "tool_use": + elif current_block.type == "tool_use": + content.append( + anthropic.types.ToolUseBlock( + id=current_block.id, + input=json.loads(current_text), + name=current_block.name, + type="tool_use", + ) + ) + response_list.append(content) + if ( + isinstance(event, anthropic.types.RawMessageStopEvent) + and event.message.stop_reason == "tool_use" + ): tool_call_params = json.loads(json_data) tool_call_result = self.get_weather(tool_call_params["location"]) messages.append({"role": "assistant", "content": content}) @@ -148,21 +165,25 @@ class WeatherAgent: def get_weather(self, location: str): """Get the current weather in a given location using latitude and longitude.""" - response = f'''Weather in {location}: + response = f"""Weather in {location}: Good morning! Expect overcast skies with intermittent showers throughout the day. Temperatures will range from a cool 15°C in the early hours to around 19°C by mid-afternoon. Light winds from the northeast at about 10 km/h will keep conditions mild. It might be a good idea to carry an umbrella if you’re heading out. Stay dry and have a great day! - ''' + """ return response -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") -async def test_response_with_toolcall( - context, explorer_api_url, proxy_url + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_response_with_tool_call( + context, explorer_api_url, proxy_url, push_to_explorer ): """Test the chat completion without streaming for the weather agent.""" - - weather_agent = WeatherAgent(proxy_url) + + weather_agent = WeatherAgent(proxy_url, push_to_explorer) query = "Tell me the weather for New York" @@ -183,38 +204,46 @@ async def test_response_with_toolcall( assert city in response[1].content[0].text.lower() responses.append(response) - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" - ) - traces = await traces_response.json() - trace = traces[-1] - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() - trace_messages = trace["messages"] + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" + ) + traces = await traces_response.json() + trace = traces[-1] + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] - assert trace_messages[0]["role"] == "user" - assert trace_messages[0]["content"] == query - assert trace_messages[1]["role"] == "assistant" - assert city in trace_messages[1]["content"].lower() - assert trace_messages[2]["role"] == "assistant" - assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" - assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() - assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" - assert city in trace_messages[4]["content"].lower() + assert trace_messages[0]["role"] == "user" + assert trace_messages[0]["content"] == query + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + assert ( + city + in trace_messages[2]["tool_calls"][0]["function"]["arguments"][ + "location" + ].lower() + ) + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" + assert city in trace_messages[4]["content"].lower() - -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") -async def test_streaming_response_with_toolcall( - context, explorer_api_url, proxy_url +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_streaming_response_with_tool_call( + context, explorer_api_url, proxy_url, push_to_explorer ): """Test the chat completion with streaming for the weather agent.""" - weather_agent = WeatherAgent(proxy_url) + weather_agent = WeatherAgent(proxy_url, push_to_explorer) query = "Tell me the weather for New York" city = "new york" @@ -226,104 +255,109 @@ async def test_streaming_response_with_toolcall( assert response[0][1].type == "tool_use" assert response[0][1].name == "get_weather" assert city in response[0][1].input["location"].lower() - + assert response[1][0].type == "text" assert city in response[1][0].text.lower() - - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" - ) - traces = await traces_response.json() - trace = traces[-1] - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() - trace_messages = trace["messages"] - assert trace_messages[0]["role"] == "user" - assert trace_messages[0]["content"] == query - assert trace_messages[1]["role"] == "assistant" - assert city in trace_messages[1]["content"].lower() - assert trace_messages[2]["role"] == "assistant" - assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" - assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() - assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" - assert city in trace_messages[4]["content"].lower() + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" + ) + traces = await traces_response.json() + + trace = traces[-1] + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] + assert trace_messages[0]["role"] == "user" + assert trace_messages[0]["content"] == query + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + assert ( + city + in trace_messages[2]["tool_calls"][0]["function"]["arguments"][ + "location" + ].lower() + ) + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" + assert city in trace_messages[4]["content"].lower() -async def test_response_with_toolcall_with_image( - context, explorer_api_url, proxy_url +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_response_with_tool_call_with_image( + context, explorer_api_url, proxy_url, push_to_explorer ): - weatherAgent = WeatherAgent(proxy_url) + """Test the chat completion with image for the weather agent.""" + weather_agent = WeatherAgent(proxy_url, push_to_explorer) - image_path1 = Path(__file__).parent.parent / "images" / "new-york.jpeg" - image_path2 = Path(__file__).parent.parent / "images" / "two-cats.png" + image_path = Path(__file__).parent.parent / "images" / "new-york.jpeg" - image1 = open(image_path1, "rb") - image2 = open(image_path2, "rb") - base64_image1 = base64.b64encode(image1.read()).decode("utf-8") - base64_image2 = base64.b64encode(image2.read()).decode("utf-8") - query = "get the weather in the city of these images" - city = "new york" - messages = [ - { - "role": "user", "content": [ - - { - "type": "text", - "text": query, - }, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": base64_image1, - } - }, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": base64_image2, - } - }, - ] - } - ] - response = weatherAgent.get_response(messages) - assert response is not None - assert response[0].role == "assistant" - assert response[0].stop_reason == "tool_use" - assert response[0].content[0].type == "text" - assert response[0].content[1].type == "tool_use" - assert city in response[0].content[1].input["location"].lower() + with image_path.open("rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode("utf-8") + query = "get the weather in the city of this image" + city = "new york" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": base64_image, + }, + }, + ], + } + ] + response = weather_agent.get_response(messages) + assert response is not None + assert response[0].role == "assistant" + assert response[0].stop_reason == "tool_use" + assert response[0].content[0].type == "text" + assert response[0].content[1].type == "tool_use" + assert city in response[0].content[1].input["location"].lower() - assert response[1].role == "assistant" - assert response[1].stop_reason == "end_turn" + assert response[1].role == "assistant" + assert response[1].stop_reason == "end_turn" - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weatherAgent.dataset_name}/traces" - ) - traces = await traces_response.json() + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces" + ) + traces = await traces_response.json() - trace = traces[-1] - trace_id = trace["id"] - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() - trace_messages = trace["messages"] - assert trace_messages[0]["role"] == "user" - assert trace_messages[1]["role"] == "assistant" - assert city in trace_messages[1]["content"].lower() - assert trace_messages[2]["role"] == "assistant" - assert trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" - assert city in trace_messages[2]["tool_calls"][0]["function"]["arguments"]["location"].lower() - assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" + trace = traces[-1] + trace_id = trace["id"] + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + trace_messages = trace["messages"] + assert trace_messages[0]["role"] == "user" + assert trace_messages[1]["role"] == "assistant" + assert city in trace_messages[1]["content"].lower() + assert trace_messages[2]["role"] == "assistant" + assert ( + trace_messages[2]["tool_calls"][0]["function"]["name"] == "get_weather" + ) + assert ( + city + in trace_messages[2]["tool_calls"][0]["function"]["arguments"][ + "location" + ].lower() + ) + assert trace_messages[3]["role"] == "tool" + assert trace_messages[4]["role"] == "assistant" diff --git a/tests/anthropic/test_anthropic_without_tool_call.py b/tests/anthropic/test_anthropic_without_tool_call.py index 2748158..cb991dc 100644 --- a/tests/anthropic/test_anthropic_without_tool_call.py +++ b/tests/anthropic/test_anthropic_without_tool_call.py @@ -1,106 +1,116 @@ -import anthropic -import os -from httpx import Client +"""Tests for the Anthropic API without tool call.""" + import datetime -import pytest +import os import sys + +import anthropic +import pytest +from httpx import Client + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from util import * # needed for pytest fixtures -pytest_plugins = ("pytest_asyncio") -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") -async def test_response_without_toolcall( - context, explorer_api_url,proxy_url +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_response_without_tool_call( + context, explorer_api_url, proxy_url, push_to_explorer ): - dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S")) - invariant_api_key = os.environ.get("INVARIANT_API_KEY","None") + """Test the Anthropic proxy without tool calling.""" + dataset_name = "claude_streaming_response_without_tool_call_test" + str( + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + ) + invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic", - ) - + http_client=Client( + headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, + ), + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/anthropic", + ) + cities = ["zurich", "new york", "london"] queries = [ "Can you introduce Zurich, Switzerland within 200 words?", "Tell me the history of New York within 100 words?", - "How's the weather in London next week?" + "How's the weather in London next week?", ] # Process each query responses = [] for query in queries: response = client.messages.create( - model="claude-3-5-sonnet-20241022", - max_tokens=1024, - messages=[{"role": "user", "content": query}], - ) + model="claude-3-5-sonnet-20241022", + max_tokens=1024, + messages=[{"role": "user", "content": query}], + ) response_text = response.content[0].text responses.append(response_text) assert response_text is not None assert cities[queries.index(query)] in response_text.lower() - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == len(queries) - - for index,trace in enumerate(traces): - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" ) - trace = await trace_response.json() - assert trace["messages"] == [ - { - "role": "user", - "content": queries[index] - }, - { - "role": "assistant", - "content": responses[index] - } - ] + traces = await traces_response.json() + assert len(traces) == len(queries) -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") -async def test_streaming_response_without_toolcall( - context, - explorer_api_url, - proxy_url - ): + for index, trace in enumerate(traces): + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + assert trace["messages"] == [ + {"role": "user", "content": queries[index]}, + {"role": "assistant", "content": responses[index]}, + ] - dataset_name = "claude_streaming_response_without_toolcall_test" + str(datetime.datetime.now().strftime("%Y%m%d%H%M%S")) - invariant_api_key = os.environ.get("INVARIANT_API_KEY","None") + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize("push_to_explorer", [False, True]) +async def test_streaming_response_without_tool_call( + context, explorer_api_url, proxy_url, push_to_explorer +): + """Test the Anthropic proxy without tool calling.""" + dataset_name = "claude_streaming_response_without_tool_call_test" + str( + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + ) + invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None") client = anthropic.Anthropic( - http_client=Client( - headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, - ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic", - ) - + http_client=Client( + headers={"Invariant-Authorization": f"Bearer {invariant_api_key}"}, + ), + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/anthropic", + ) + cities = ["zurich", "new york", "london"] queries = [ "Can you introduce Zurich, Switzerland within 200 words?", "Tell me the history of New York within 100 words?", - "How's the weather in London next week?" + "How's the weather in London next week?", ] # Process each query responses = [] - for index,query in enumerate(queries): - messages = [ - { - "role": "user", - "content": query - } - ] + for index, query in enumerate(queries): + messages = [{"role": "user", "content": query}] response_text = "" - + with client.messages.stream( model="claude-3-5-sonnet-20241022", max_tokens=1024, @@ -113,26 +123,21 @@ async def test_streaming_response_without_toolcall( assert response_text is not None assert cities[queries.index(query)] in response_text.lower() - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == len(queries) - - for index,trace in enumerate(traces): - trace_id = trace["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" + if push_to_explorer: + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" ) - trace = await trace_response.json() - assert trace["messages"] == [ - { - "role": "user", - "content": queries[index] - }, - { - "role": "assistant", - "content": responses[index] - } - ] \ No newline at end of file + traces = await traces_response.json() + assert len(traces) == len(queries) + + for index, trace in enumerate(traces): + trace_id = trace["id"] + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + assert trace["messages"] == [ + {"role": "user", "content": queries[index]}, + {"role": "assistant", "content": responses[index]}, + ] diff --git a/tests/open_ai/test_chat_with_tool_call.py b/tests/open_ai/test_chat_with_tool_call.py index 72f7f8f..12999f4 100644 --- a/tests/open_ai/test_chat_with_tool_call.py +++ b/tests/open_ai/test_chat_with_tool_call.py @@ -7,6 +7,7 @@ import uuid import pytest from httpx import Client + # add tests folder (parent) to sys.path from openai import OpenAI @@ -18,8 +19,9 @@ pytest_plugins = ("pytest_asyncio",) @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize("push_to_explorer", [False, True]) async def test_chat_completion_with_tool_call_without_streaming( - context, explorer_api_url, proxy_url + context, explorer_api_url, proxy_url, push_to_explorer ): """ Test the chat completions proxy calls with tool calling and response processing @@ -33,7 +35,9 @@ async def test_chat_completion_with_tool_call_without_streaming( "Invariant-Authorization": "Bearer " }, # This key is not used for local tests ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai", + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/openai", ) chat_response = client.chat.completions.create( @@ -96,36 +100,38 @@ async def test_chat_completion_with_tool_call_without_streaming( ) assert "15°C" in chat_response_final.choices[0].message.content - # Fetch the trace ids for the dataset - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 - trace_id = traces[0]["id"] + if push_to_explorer: + # Fetch the trace ids for the dataset + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() - # Verify the trace messages - expected_messages = history + [ - { - "role": "assistant", - "content": chat_response_final.choices[0].message.content, - } - ] - expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads( - expected_messages[1]["tool_calls"][0]["function"]["arguments"] - ) - assert trace["messages"] == expected_messages + # Verify the trace messages + expected_messages = history + [ + { + "role": "assistant", + "content": chat_response_final.choices[0].message.content, + } + ] + expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads( + expected_messages[1]["tool_calls"][0]["function"]["arguments"] + ) + assert trace["messages"] == expected_messages @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize("push_to_explorer", [False, True]) async def test_chat_completion_with_tool_call_with_streaming( - context, explorer_api_url, proxy_url + context, explorer_api_url, proxy_url, push_to_explorer ): """ Test the chat completions proxy calls with tool calling and response processing @@ -139,7 +145,9 @@ async def test_chat_completion_with_tool_call_with_streaming( "Invariant-Authorization": "Bearer " }, # This key is not used for local tests ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai", + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/openai", ) chat_response = client.chat.completions.create( @@ -209,23 +217,24 @@ async def test_chat_completion_with_tool_call_with_streaming( if chunk.choices and chunk.choices[0].delta.content: final_response["content"] += chunk.choices[0].delta.content - # Fetch the trace ids for the dataset - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 - trace_id = traces[0]["id"] + if push_to_explorer: + # Fetch the trace ids for the dataset + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() - # Verify the trace messages - expected_messages = history + [final_response] - expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads( - expected_messages[1]["tool_calls"][0]["function"]["arguments"] - ) - assert trace["messages"] == expected_messages + # Verify the trace messages + expected_messages = history + [final_response] + expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads( + expected_messages[1]["tool_calls"][0]["function"]["arguments"] + ) + assert trace["messages"] == expected_messages diff --git a/tests/open_ai/test_chat_without_tool_calls.py b/tests/open_ai/test_chat_without_tool_calls.py index d0e90c1..692edba 100644 --- a/tests/open_ai/test_chat_without_tool_calls.py +++ b/tests/open_ai/test_chat_without_tool_calls.py @@ -21,8 +21,13 @@ pytest_plugins = ("pytest_asyncio",) @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") -@pytest.mark.parametrize("do_stream", [True, False]) -async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream): +@pytest.mark.parametrize( + "do_stream, push_to_explorer", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_chat_completion( + context, explorer_api_url, proxy_url, do_stream, push_to_explorer +): """Test the chat completions proxy calls without tool calling.""" dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4()) @@ -32,7 +37,9 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream): "Invariant-Authorization": "Bearer " }, # This key is not used for local tests ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai", + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/openai", ) chat_response = client.chat.completions.create( @@ -53,35 +60,39 @@ async def test_chat_completion(context, explorer_api_url, proxy_url, do_stream): assert "PARIS" in full_response.upper() expected_assistant_message = full_response - # Fetch the trace ids for the dataset - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 - trace_id = traces[0]["id"] + if push_to_explorer: + # Fetch the trace ids for the dataset + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() - # Verify the trace messages - assert trace["messages"] == [ - { - "role": "user", - "content": "What is the capital of France?", - }, - { - "role": "assistant", - "content": expected_assistant_message, - }, - ] + # Verify the trace messages + assert trace["messages"] == [ + { + "role": "user", + "content": "What is the capital of France?", + }, + { + "role": "assistant", + "content": expected_assistant_message, + }, + ] @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") -async def test_chat_completion_with_image(context, explorer_api_url, proxy_url): +@pytest.mark.parametrize("push_to_explorer", [True, False]) +async def test_chat_completion_with_image( + context, explorer_api_url, proxy_url, push_to_explorer +): """Test the chat completions proxy works with image.""" dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4()) @@ -91,7 +102,9 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url): "Invariant-Authorization": "Bearer " }, # This key is not used for local tests ), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai", + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai" + if push_to_explorer + else f"{proxy_url}/api/v1/proxy/openai", ) image_path = Path(__file__).parent.parent / "images" / "two-cats.png" with image_path.open("rb") as image_file: @@ -121,37 +134,43 @@ async def test_chat_completion_with_image(context, explorer_api_url, proxy_url): assert "TWO" in chat_response.choices[0].message.content.upper() - # Fetch the trace ids for the dataset - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 - trace_id = traces[0]["id"] + if push_to_explorer: + # Fetch the trace ids for the dataset + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] - # Fetch the trace - trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await trace_response.json() + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() - # Verify the trace messages - assert trace["messages"] == [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How many cats are there in this image?"}, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64," + base64_image}, - }, - ], - }, - { - "role": "assistant", - "content": chat_response.choices[0].message.content, - }, - ] + # Verify the trace messages + assert trace["messages"] == [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "How many cats are there in this image?", + }, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64," + base64_image + }, + }, + ], + }, + { + "role": "assistant", + "content": chat_response.choices[0].message.content, + }, + ] @pytest.mark.skip(reason="Skipping this test: OpenAI error scenario")