# Mirror of https://github.com/invariantlabs-ai/invariant-gateway.git
# Synced 2026-05-18 21:28:07 +02:00 - 268 lines - 9.2 KiB - Python
"""Test the chat completions gateway calls without tool calling."""
|
|
|
|
import base64
|
|
import os
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
# Add integration folder (parent) to sys.path
|
|
# Appends the directory two levels above this file to the import path.
# NOTE(review): presumably this is what lets the top-level `from utils import ...`
# below resolve - confirm that `utils` lives in that folder.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from utils import get_open_ai_client
|
|
|
|
import pytest
|
|
import requests
|
|
from httpx import Client
|
|
from openai import NotFoundError, OpenAI
|
|
|
|
# Pytest plugins
|
|
# Asks pytest to load the pytest-asyncio plugin; the tests in this module are
# declared with `async def`.
pytest_plugins = ("pytest_asyncio",)
|
|
|
|
|
|
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_chat_completion(
    explorer_api_url, gateway_url, do_stream, push_to_explorer
):
    """Test the chat completions gateway calls without tool calling.

    Sends one user question through the gateway, streamed or not depending on
    ``do_stream``, and checks that the answer mentions Paris. When
    ``push_to_explorer`` is true, it also checks that exactly one trace is
    returned for the uniquely named dataset and that the trace messages match
    the conversation.
    """
    question = "What is the capital of France?"
    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
    client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)

    chat_response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": question}],
        stream=do_stream,
    )

    # Verify the chat response
    if do_stream:
        # Concatenate the streamed deltas; chunks without choices or without
        # content are skipped.
        expected_assistant_message = "".join(
            chunk.choices[0].delta.content
            for chunk in chat_response
            if chunk.choices and chunk.choices[0].delta.content
        )
    else:
        expected_assistant_message = chat_response.choices[0].message.content
    assert "PARIS" in expected_assistant_message.upper()

    if not push_to_explorer:
        return

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    # Surface an HTTP error as such instead of failing later on the error body.
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    # Annotations are not part of the comparison below, so drop them if present.
    for message in trace["messages"]:
        message.pop("annotations", None)

    # Verify the trace messages
    assert trace["messages"] == [
        {
            "role": "user",
            "content": question,
        },
        {
            "role": "assistant",
            "content": expected_assistant_message,
        },
    ]
|
|
|
|
|
|
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [True, False])
async def test_chat_completion_with_image(
    explorer_api_url, gateway_url, push_to_explorer
):
    """Test the chat completions gateway works with image.

    Sends a text question plus a base64-encoded PNG as a data URL and checks
    that the answer mentions "two" or "2". When ``push_to_explorer`` is true,
    it also checks that exactly one trace is returned for the dataset, that the
    user message keeps its text and image parts, and that the assistant
    message matches the response.
    """
    question = "How many cats are there in this image?"
    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
    client = get_open_ai_client(gateway_url, push_to_explorer, dataset_name)

    image_path = Path(__file__).parent.parent / "resources" / "images" / "two-cats.png"
    base64_image = base64.b64encode(image_path.read_bytes()).decode("utf-8")

    chat_response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": question,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}"
                        },
                    },
                ],
            }
        ],
        max_tokens=100,
    )

    answer = chat_response.choices[0].message.content
    # Accept either the spelled-out word (any case) or the digit.
    assert "TWO" in answer.upper() or "2" in answer

    if not push_to_explorer:
        return

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    # Surface an HTTP error as such instead of failing later on the error body.
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    # Annotations are not part of the comparison below, so drop them if present.
    for message in trace["messages"]:
        message.pop("annotations", None)

    # Verify the trace messages
    messages = trace["messages"]
    assert len(messages) == 2

    user_content = messages[0]["content"]
    assert user_content[0]["type"] == "text"
    assert user_content[0]["text"] == question
    # Only the type of the image part is checked, not its payload.
    assert user_content[1]["type"] == "image_url"

    assert messages[1] == {
        "role": "assistant",
        "content": answer,
    }
|
|
|
|
|
|
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
async def test_chat_completion_with_invariant_key_in_openai_key_header(
    explorer_api_url, gateway_url
):
    """Test the gateway with the Invariant API key carried in the OpenAI key.

    ``OPENAI_API_KEY`` is temporarily replaced by the original key suffixed
    with ``;invariant-auth=<INVARIANT_API_KEY>``, and the HTTP client is built
    without any extra headers. The test then checks the answer and that exactly
    one trace with the expected messages is returned for the dataset.
    """
    question = "What is the capital of France?"
    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
    openai_api_key = os.getenv("OPENAI_API_KEY")
    invariant_key_suffix = f";invariant-auth={os.getenv('INVARIANT_API_KEY')}"

    # The client is given no explicit API key, so the combined key reaches it
    # only through the patched environment variable.
    with patch.dict(
        os.environ,
        {"OPENAI_API_KEY": openai_api_key + invariant_key_suffix},
    ):
        client = OpenAI(
            http_client=Client(),
            base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
        )

        chat_response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": question}],
            stream=False,
        )

    # Verify the chat response
    expected_assistant_message = chat_response.choices[0].message.content
    assert "PARIS" in expected_assistant_message.upper()

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    # Surface an HTTP error as such instead of failing later on the error body.
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    # Annotations are not part of the comparison below, so drop them if present.
    for message in trace["messages"]:
        message.pop("annotations", None)

    # Verify the trace messages
    assert trace["messages"] == [
        {
            "role": "user",
            "content": question,
        },
        {
            "role": "assistant",
            "content": expected_assistant_message,
        },
    ]
|
|
|
|
|
|
@pytest.mark.skip(reason="Skipping this test: OpenAI error scenario")
@pytest.mark.parametrize("do_stream", [True, False])
async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
    """Test the chat completions gateway call when OpenAI API fails.

    Requests a model that is expected to be unavailable and checks that the
    client raises ``NotFoundError`` with status code 404, for both streaming
    and non-streaming calls.
    """
    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
    client = OpenAI(
        http_client=Client(
            headers={
                "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
            },  # This key is not used for local tests
        ),
        base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
    )

    # Expect the specific error type so that any other exception fails the
    # test with its own traceback.
    with pytest.raises(NotFoundError) as exc_info:
        client.chat.completions.create(
            model="gpt-4-vision-preview",  # This model is not available so we get a 404 error
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "How many cats are there in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/png;base64," + "01234"},
                        },
                    ],
                }
            ],
            stream=do_stream,
        )

    # Checked after the `with` block: statements following the raising call
    # inside it would never run.
    assert exc_info.value.status_code == 404
|