mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-15 04:30:25 +02:00
Remove unnecessary playwright dependency for tests.
This commit is contained in:
@@ -292,10 +292,24 @@ This will launch Gateway at [http://localhost:8005/api/v1/gateway/](http://local
|
||||
|
||||
By default Gateway points to the public Explorer instance at `explorer.invariantlabs.ai`. To point it to your local Explorer instance, modify the `INVARIANT_API_URL` value inside `.env`. Follow instructions in `.env` on how to point to the local instance.
|
||||
|
||||
### **Run Tests**
|
||||
### **Run Unit Tests**
|
||||
|
||||
To run tests, execute:
|
||||
To run the unit tests, execute:
|
||||
|
||||
```bash
|
||||
./run.sh tests
|
||||
bash run.sh unit-tests
|
||||
```
|
||||
|
||||
### **Run Integration Tests**
|
||||
|
||||
To run the integration tests, execute:
|
||||
|
||||
```bash
|
||||
bash run.sh integration-tests
|
||||
```
|
||||
|
||||
To run a subset of the integration tests, execute:
|
||||
|
||||
```bash
|
||||
bash run.sh integration-tests open_ai/test_chat_with_tool_call.py
|
||||
```
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/playwright/python:v1.50.0-noble
|
||||
FROM python:3.11-slim
|
||||
|
||||
RUN mkdir -p /tests
|
||||
COPY ./integration/requirements.txt /tests/requirements.txt
|
||||
@@ -6,4 +6,4 @@ WORKDIR /tests
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
ENTRYPOINT ["pytest", "--capture=tee-sys", "--tracing", "off", "--junit-xml=/tests/results/test-results-all.xml", "-s", "-vv"]
|
||||
ENTRYPOINT ["pytest", "--capture=tee-sys", "--junit-xml=/tests/results/test-results-all.xml", "-s", "-vv"]
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Test the Anthropic gateway with Invariant key in the ANTHROPIC_API_KEY."""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# Add integration folder (parent) to sys.path
|
||||
@@ -11,10 +11,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -23,13 +22,11 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
async def test_gateway_with_invariant_key_in_anthropic_key_header(
|
||||
context, gateway_url, explorer_api_url
|
||||
gateway_url, explorer_api_url
|
||||
):
|
||||
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
|
||||
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
dataset_name = "claude_header_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
@@ -60,17 +57,18 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
|
||||
trace_id = traces[0]["id"]
|
||||
get_trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
get_trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await get_trace_response.json()
|
||||
trace = get_trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Test the Anthropic messages API with tool call for the weather agent."""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -14,10 +14,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -26,9 +25,7 @@ class WeatherAgent:
|
||||
"""Weather agent to get the current weather in a given location."""
|
||||
|
||||
def __init__(self, gateway_url, push_to_explorer):
|
||||
self.dataset_name = "claude_weather_agent_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
self.client = anthropic.Anthropic(
|
||||
http_client=Client(
|
||||
@@ -183,9 +180,7 @@ class WeatherAgent:
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_with_tool_call(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_explorer):
|
||||
"""Test the chat completion without streaming for the weather agent."""
|
||||
|
||||
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
|
||||
@@ -213,17 +208,18 @@ async def test_response_with_tool_call(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
@@ -248,7 +244,7 @@ async def test_response_with_tool_call(
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_streaming_response_with_tool_call(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completion with streaming for the weather agent."""
|
||||
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
|
||||
@@ -271,18 +267,19 @@ async def test_streaming_response_with_tool_call(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[0]["content"] == query
|
||||
@@ -306,7 +303,7 @@ async def test_streaming_response_with_tool_call(
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_with_tool_call_with_image(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completion with image for the weather agent."""
|
||||
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
|
||||
@@ -348,17 +345,18 @@ async def test_response_with_tool_call_with_image(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
|
||||
trace = traces[-1]
|
||||
trace_id = trace["id"]
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
trace_messages = trace["messages"]
|
||||
assert trace_messages[0]["role"] == "user"
|
||||
assert trace_messages[1]["role"] == "assistant"
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""Tests for the Anthropic API without tool call."""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -23,12 +22,10 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_response_without_tool_call(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
@@ -64,19 +61,21 @@ async def test_response_without_tool_call(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index, trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{"role": "user", "content": queries[index]},
|
||||
{"role": "assistant", "content": responses[index]},
|
||||
@@ -88,12 +87,10 @@ async def test_response_without_tool_call(
|
||||
)
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_streaming_response_without_tool_call(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test the Anthropic gateway without tool calling."""
|
||||
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
)
|
||||
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
|
||||
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
@@ -133,19 +130,21 @@ async def test_streaming_response_without_tool_call(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == len(queries)
|
||||
|
||||
for index, trace in enumerate(traces):
|
||||
trace_id = trace["id"]
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
assert trace["messages"] == [
|
||||
{"role": "user", "content": queries[index]},
|
||||
{"role": "assistant", "content": responses[index]},
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Util functions for tests"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def gateway_url():
|
||||
"""Get the gateway URL from the environment variable"""
|
||||
if "INVARIANT_GATEWAY_API_URL" in os.environ:
|
||||
return os.environ["INVARIANT_GATEWAY_API_URL"]
|
||||
raise ValueError("Please set the INVARIANT_GATEWAY_API_URL environment variable")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def explorer_api_url():
|
||||
"""Get the explorer API URL from the environment variable"""
|
||||
if "INVARIANT_API_URL" in os.environ:
|
||||
return os.environ["INVARIANT_API_URL"]
|
||||
raise ValueError("Please set the INVARIANT_API_URL environment variable")
|
||||
@@ -9,14 +9,14 @@ import uuid
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def set_light_values(brightness: int, color_temp: str) -> dict[str, int | str]:
|
||||
"""Set the brightness and color temperature of a room light. (mock API).
|
||||
|
||||
@@ -49,32 +49,35 @@ SET_LIGHT_VALUES_TOOL_CALL = {
|
||||
}
|
||||
|
||||
|
||||
async def _verify_trace_from_explorer(
|
||||
context, explorer_api_url, dataset_name, expected_final_assistant_message
|
||||
def _verify_trace_from_explorer(
|
||||
explorer_api_url, dataset_name, expected_final_assistant_message
|
||||
) -> None:
|
||||
# Fetch the trace ids for the dataset.
|
||||
# There will be 2 traces - the first will contain the system instruction, user prompt
|
||||
# and the assistant tool call.
|
||||
# The second will contain the system instruction, user prompt, the assistant tool call,
|
||||
# the tool response and the assistant response.
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 2
|
||||
trace_id_1 = traces[0]["id"]
|
||||
trace_id_2 = traces[1]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response_1 = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id_1}"
|
||||
trace_response_1 = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id_1}",
|
||||
timeout=5,
|
||||
)
|
||||
trace_1 = await trace_response_1.json()
|
||||
trace_1 = trace_response_1.json()
|
||||
|
||||
trace_response_2 = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id_2}"
|
||||
trace_response_2 = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id_2}",
|
||||
timeout=5,
|
||||
)
|
||||
trace_2 = await trace_response_2.json()
|
||||
trace_2 = trace_response_2.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace_1["messages"] == [
|
||||
@@ -133,13 +136,13 @@ async def _verify_trace_from_explorer(
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_generate_content_with_tool_call(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer, do_stream
|
||||
explorer_api_url, gateway_url, push_to_explorer, do_stream
|
||||
):
|
||||
"""
|
||||
Test the generate content gateway calls with tool calling and response processing
|
||||
without streaming.
|
||||
"""
|
||||
dataset_name = "test-dataset-gemini-tool-call-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
@@ -190,6 +193,6 @@ async def test_generate_content_with_tool_call(
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
await _verify_trace_from_explorer(
|
||||
context, explorer_api_url, dataset_name, expected_final_assistant_message
|
||||
_verify_trace_from_explorer(
|
||||
explorer_api_url, dataset_name, expected_final_assistant_message
|
||||
)
|
||||
|
||||
@@ -12,10 +12,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import PIL.Image
|
||||
import requests
|
||||
from google import genai
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -26,10 +25,10 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_generate_content(
|
||||
context, explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test the generate content gateway calls without tool calling."""
|
||||
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
http_options={
|
||||
@@ -78,18 +77,19 @@ async def test_generate_content(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
@@ -111,10 +111,10 @@ async def test_generate_content(
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [True, False])
|
||||
async def test_generate_content_with_image(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test that generate content gateway calls work with image."""
|
||||
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
|
||||
client = genai.Client(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
@@ -147,18 +147,19 @@ async def test_generate_content_with_image(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
# Verify the trace messages
|
||||
assert len(trace["messages"]) == 2
|
||||
assert trace["messages"][0]["role"] == "user"
|
||||
@@ -175,10 +176,10 @@ async def test_generate_content_with_image(
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
|
||||
async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
context, explorer_api_url, gateway_url
|
||||
explorer_api_url, gateway_url
|
||||
):
|
||||
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
|
||||
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
@@ -208,18 +209,20 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
assert trace["messages"] == [
|
||||
|
||||
@@ -10,11 +10,10 @@ import uuid
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from openai import OpenAI
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -22,13 +21,13 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_chat_completion_with_tool_call_without_streaming(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""
|
||||
Test the chat completions gateway calls with tool calling and response processing
|
||||
without streaming.
|
||||
"""
|
||||
dataset_name = "test-dataset-open-ai-tool-call-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
@@ -106,18 +105,20 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
for message in trace["messages"]:
|
||||
message.pop("annotations", None)
|
||||
@@ -138,13 +139,13 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [False, True])
|
||||
async def test_chat_completion_with_tool_call_with_streaming(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""
|
||||
Test the chat completions gateway calls with tool calling and response processing
|
||||
while streaming.
|
||||
"""
|
||||
dataset_name = "test-dataset-open-ai-tool-call-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
@@ -229,18 +230,20 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
# Verify the trace messages
|
||||
expected_messages = history + [final_response]
|
||||
|
||||
@@ -12,11 +12,10 @@ from unittest.mock import patch
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from httpx import Client
|
||||
from openai import NotFoundError, OpenAI
|
||||
|
||||
from util import * # Needed for pytest fixtures
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@@ -27,10 +26,10 @@ pytest_plugins = ("pytest_asyncio",)
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_chat_completion(
|
||||
context, explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test the chat completions gateway calls without tool calling."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
@@ -66,18 +65,20 @@ async def test_chat_completion(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
for message in trace["messages"]:
|
||||
message.pop("annotations", None)
|
||||
@@ -98,10 +99,10 @@ async def test_chat_completion(
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize("push_to_explorer", [True, False])
|
||||
async def test_chat_completion_with_image(
|
||||
context, explorer_api_url, gateway_url, push_to_explorer
|
||||
explorer_api_url, gateway_url, push_to_explorer
|
||||
):
|
||||
"""Test the chat completions gateway works with image."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
@@ -149,18 +150,20 @@ async def test_chat_completion_with_image(
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
for message in trace["messages"]:
|
||||
message.pop("annotations", None)
|
||||
@@ -181,10 +184,10 @@ async def test_chat_completion_with_image(
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
context, explorer_api_url, gateway_url
|
||||
explorer_api_url, gateway_url
|
||||
):
|
||||
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
|
||||
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
@@ -210,18 +213,20 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = await traces_response.json()
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = await context.request.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}"
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = await trace_response.json()
|
||||
trace = trace_response.json()
|
||||
|
||||
for message in trace["messages"]:
|
||||
message.pop("annotations", None)
|
||||
@@ -243,14 +248,14 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
|
||||
"""Test the chat completions gateway call when OpenAI API fails."""
|
||||
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": "Bearer <some-key>"
|
||||
}, # This key is not used for local tests
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{"test-dataset-open-ai-" + str(uuid.uuid4())}/openai",
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
|
||||
)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
|
||||
@@ -4,5 +4,4 @@ openai
|
||||
pillow
|
||||
pytest
|
||||
pytest-asyncio
|
||||
pytest-playwright
|
||||
tavily-python
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Util functions for tests"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gateway_url():
|
||||
"""Get the gateway URL from the environment variable"""
|
||||
if "INVARIANT_GATEWAY_API_URL" in os.environ:
|
||||
return os.environ["INVARIANT_GATEWAY_API_URL"]
|
||||
raise ValueError("Please set the INVARIANT_GATEWAY_API_URL environment variable")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def explorer_api_url():
|
||||
"""Get the explorer API URL from the environment variable"""
|
||||
if "INVARIANT_API_URL" in os.environ:
|
||||
return os.environ["INVARIANT_API_URL"]
|
||||
raise ValueError("Please set the INVARIANT_API_URL environment variable")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def playwright(scope="session"):
|
||||
"""Fixture to create a Playwright instance"""
|
||||
async with async_playwright() as playwright_instance:
|
||||
yield playwright_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def browser(playwright, scope="session"):
|
||||
"""Fixture to create a browser instance"""
|
||||
firefox_browser = await playwright.firefox.launch(headless=True)
|
||||
yield firefox_browser
|
||||
await firefox_browser.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def context(browser):
|
||||
"""Fixture to create a browser context"""
|
||||
browser_context = await browser.new_context(ignore_https_errors=True)
|
||||
yield browser_context
|
||||
await browser_context.close()
|
||||
Reference in New Issue
Block a user