mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 07:19:42 +02:00
Add test for OpenAI so that we verify that the Invariant API Key can be passed inside the OpenAI Key header.
This commit is contained in:
@@ -19,9 +19,8 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_proxy_with_invariant_key_in_anthropic_key(
|
||||
context, proxy_url, explorer_api_url, push_to_explorer
|
||||
async def test_proxy_with_invariant_key_in_anthropic_key_header(
|
||||
context, proxy_url, explorer_api_url
|
||||
):
|
||||
"""Test the Anthropic proxy with Invariant key in the Anthropic key"""
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
@@ -37,9 +36,7 @@ async def test_proxy_with_invariant_key_in_anthropic_key(
|
||||
):
|
||||
client = anthropic.Anthropic(
|
||||
http_client=Client(),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic"
|
||||
if push_to_explorer
|
||||
else f"{proxy_url}/api/v1/proxy/anthropic",
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/anthropic",
|
||||
)
|
||||
response = client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
@@ -56,22 +53,21 @@ async def test_proxy_with_invariant_key_in_anthropic_key(
|
||||
response_text = response.content[0].text
|
||||
assert "zurich" in response_text.lower()
|
||||
|
||||
if push_to_explorer:
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
|
||||
},
|
||||
{"role": "assistant", "content": response_text},
|
||||
]
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me an introduction to Zurich, Switzerland within 200 words.",
|
||||
},
|
||||
{"role": "assistant", "content": response_text},
|
||||
]
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import Client
|
||||
@@ -173,6 +174,59 @@ async def test_chat_completion_with_image(
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
context, explorer_api_url, proxy_url
|
||||
):
|
||||
"""Test the chat completions proxy calls with the Invariant API Key in the OpenAI Key header."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_API_KEY": openai_api_key + "|invariant-auth: <not needed for test>"},
|
||||
):
|
||||
client = OpenAI(
|
||||
http_client=Client(),
|
||||
base_url=f"{proxy_url}/api/v1/proxy/{dataset_name}/openai",
|
||||
)
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify the chat response
|
||||
assert "PARIS" in chat_response.choices[0].message.content.upper()
|
||||
expected_assistant_message = chat_response.choices[0].message.content
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": expected_assistant_message,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test: OpenAI error scenario")
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_chat_completion_with_openai_exception(proxy_url, do_stream):
|
||||
|
||||
Reference in New Issue
Block a user