diff --git a/tests/anthropic/test_anthropic_header_with_invariant_key.py b/tests/anthropic/test_anthropic_header_with_invariant_key.py index bbe5bf2..d5304c9 100644 --- a/tests/anthropic/test_anthropic_header_with_invariant_key.py +++ b/tests/anthropic/test_anthropic_header_with_invariant_key.py @@ -19,9 +19,8 @@ pytest_plugins = ("pytest_asyncio",) @pytest.mark.skipif( not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" ) -@pytest.mark.parametrize("push_to_explorer", [False, True]) -async def test_proxy_with_invariant_key_in_anthropic_key( - context, proxy_url, explorer_api_url, push_to_explorer +async def test_proxy_with_invariant_key_in_anthropic_key_header( + context, proxy_url, explorer_api_url ): """Test the Anthropic proxy with Invariant key in the Anthropic key""" anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") @@ -37,9 +36,7 @@ async def test_proxy_with_invariant_key_in_anthropic_key( ): client = anthropic.Anthropic( http_client=Client(), - base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic" - if push_to_explorer - else f"{proxy_url}/api/v1/proxy/anthropic", + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic", ) response = client.messages.create( model="claude-3-5-sonnet-20241022", @@ -56,22 +53,21 @@ async def test_proxy_with_invariant_key_in_anthropic_key( response_text = response.content[0].text assert "zurich" in response_text.lower() - if push_to_explorer: - traces_response = await context.request.get( - f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" - ) - traces = await traces_response.json() - assert len(traces) == 1 + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 - trace_id = traces[0]["id"] - get_trace_response = await context.request.get( - f"{explorer_api_url}/api/v1/trace/{trace_id}" - ) - trace = await get_trace_response.json() - assert trace["messages"] == [ - { - "role": "user", - "content": "Give me an introduction to Zurich, Switzerland within 200 words.", - }, - {"role": "assistant", "content": response_text}, - ] + trace_id = traces[0]["id"] + get_trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await get_trace_response.json() + assert trace["messages"] == [ + { + "role": "user", + "content": "Give me an introduction to Zurich, Switzerland within 200 words.", + }, + {"role": "assistant", "content": response_text}, + ] diff --git a/tests/open_ai/test_chat_without_tool_calls.py b/tests/open_ai/test_chat_without_tool_calls.py index 692edba..c3355d4 100644 --- a/tests/open_ai/test_chat_without_tool_calls.py +++ b/tests/open_ai/test_chat_without_tool_calls.py @@ -5,6 +5,7 @@ import os import sys import uuid from pathlib import Path +from unittest.mock import patch import pytest from httpx import Client @@ -173,6 +174,59 @@ async def test_chat_completion_with_image( ] +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +async def test_chat_completion_with_invariant_key_in_openai_key_header( + context, explorer_api_url, proxy_url +): + """Test the chat completions proxy calls with the Invariant API Key in the OpenAI Key header.""" + dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4()) + openai_api_key = os.getenv("OPENAI_API_KEY") + with patch.dict( + os.environ, + {"OPENAI_API_KEY": openai_api_key + "|invariant-auth: "}, + ): + client = OpenAI( + http_client=Client(), + base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai", + ) + + chat_response = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "What is the capital of France?"}], + stream=False, + ) + + # Verify the chat response + assert "PARIS" in chat_response.choices[0].message.content.upper() + expected_assistant_message = chat_response.choices[0].message.content + + # Fetch the trace ids for the dataset + traces_response = await context.request.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces" + ) + traces = await traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = await context.request.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}" + ) + trace = await trace_response.json() + + # Verify the trace messages + assert trace["messages"] == [ + { + "role": "user", + "content": "What is the capital of France?", + }, + { + "role": "assistant", + "content": expected_assistant_message, + }, + ] + + @pytest.mark.skip(reason="Skipping this test: OpenAI error scenario") @pytest.mark.parametrize("do_stream", [True, False]) async def test_chat_completion_with_openai_exception(proxy_url, do_stream):