Remove unnecessary playwright dependency for tests.

This commit is contained in:
Hemang
2025-03-13 09:29:05 +01:00
committed by Hemang Sarkar
parent 31b4e9bba7
commit c4ecc01a59
12 changed files with 202 additions and 205 deletions
+17 -3
View File
@@ -292,10 +292,24 @@ This will launch Gateway at [http://localhost:8005/api/v1/gateway/](http://local
By default, Gateway points to the public Explorer instance at `explorer.invariantlabs.ai`. To point it to your local Explorer instance instead, change the `INVARIANT_API_URL` value inside `.env`, following the instructions given in that file.
### **Run Tests**
### **Run Unit Tests**
To run tests, execute:
To run the unit tests, execute:
```bash
./run.sh tests
bash run.sh unit-tests
```
### **Run Integration Tests**
To run the integration tests, execute:
```bash
bash run.sh integration-tests
```
To run a subset of the integration tests, execute:
```bash
bash run.sh integration-tests open_ai/test_chat_with_tool_call.py
```
+2 -2
View File
@@ -1,4 +1,4 @@
FROM mcr.microsoft.com/playwright/python:v1.50.0-noble
FROM python:3.11-slim
RUN mkdir -p /tests
COPY ./integration/requirements.txt /tests/requirements.txt
@@ -6,4 +6,4 @@ WORKDIR /tests
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt
ENTRYPOINT ["pytest", "--capture=tee-sys", "--tracing", "off", "--junit-xml=/tests/results/test-results-all.xml", "-s", "-vv"]
ENTRYPOINT ["pytest", "--capture=tee-sys", "--junit-xml=/tests/results/test-results-all.xml", "-s", "-vv"]
@@ -1,9 +1,9 @@
"""Test the Anthropic gateway with Invariant key in the ANTHROPIC_API_KEY."""
import datetime
import os
import sys
import time
import uuid
from unittest.mock import patch
# Add integration folder (parent) to sys.path
@@ -11,10 +11,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
import requests
from httpx import Client
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -23,13 +22,11 @@ pytest_plugins = ("pytest_asyncio",)
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
async def test_gateway_with_invariant_key_in_anthropic_key_header(
context, gateway_url, explorer_api_url
gateway_url, explorer_api_url
):
"""Test the Anthropic gateway with Invariant key in the Anthropic key"""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
dataset_name = "claude_header_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
with patch.dict(
os.environ,
{
@@ -60,17 +57,18 @@ async def test_gateway_with_invariant_key_in_anthropic_key_header(
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
get_trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
get_trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await get_trace_response.json()
trace = get_trace_response.json()
assert trace["messages"] == [
{
"role": "user",
@@ -1,11 +1,11 @@
"""Test the Anthropic messages API with tool call for the weather agent."""
import base64
import datetime
import json
import os
import sys
import time
import uuid
from pathlib import Path
from typing import Dict, List
@@ -14,10 +14,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
import requests
from httpx import Client
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -26,9 +25,7 @@ class WeatherAgent:
"""Weather agent to get the current weather in a given location."""
def __init__(self, gateway_url, push_to_explorer):
self.dataset_name = "claude_weather_agent_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
self.dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
self.client = anthropic.Anthropic(
http_client=Client(
@@ -183,9 +180,7 @@ class WeatherAgent:
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_with_tool_call(
context, explorer_api_url, gateway_url, push_to_explorer
):
async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_explorer):
"""Test the chat completion without streaming for the weather agent."""
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
@@ -213,17 +208,18 @@ async def test_response_with_tool_call(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await trace_response.json()
trace = trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
@@ -248,7 +244,7 @@ async def test_response_with_tool_call(
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_streaming_response_with_tool_call(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test the chat completion with streaming for the weather agent."""
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
@@ -271,18 +267,19 @@ async def test_streaming_response_with_tool_call(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await trace_response.json()
trace = trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[0]["content"] == query
@@ -306,7 +303,7 @@ async def test_streaming_response_with_tool_call(
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_with_tool_call_with_image(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test the chat completion with image for the weather agent."""
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
@@ -348,17 +345,18 @@ async def test_response_with_tool_call_with_image(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{weather_agent.dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
trace = traces[-1]
trace_id = trace["id"]
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await trace_response.json()
trace = trace_response.json()
trace_messages = trace["messages"]
assert trace_messages[0]["role"] == "user"
assert trace_messages[1]["role"] == "assistant"
@@ -1,19 +1,18 @@
"""Tests for the Anthropic API without tool call."""
import datetime
import os
import sys
import time
import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import anthropic
import pytest
import requests
from httpx import Client
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -23,12 +22,10 @@ pytest_plugins = ("pytest_asyncio",)
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_response_without_tool_call(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
@@ -64,19 +61,21 @@ async def test_response_without_tool_call(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == len(queries)
for index, trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
assert trace["messages"] == [
{"role": "user", "content": queries[index]},
{"role": "assistant", "content": responses[index]},
@@ -88,12 +87,10 @@ async def test_response_without_tool_call(
)
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_streaming_response_without_tool_call(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test the Anthropic gateway without tool calling."""
dataset_name = "claude_streaming_response_without_tool_call_test" + str(
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
)
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
invariant_api_key = os.environ.get("INVARIANT_API_KEY", "None")
client = anthropic.Anthropic(
@@ -133,19 +130,21 @@ async def test_streaming_response_without_tool_call(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == len(queries)
for index, trace in enumerate(traces):
trace_id = trace["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
assert trace["messages"] == [
{"role": "user", "content": queries[index]},
{"role": "assistant", "content": responses[index]},
+20
View File
@@ -0,0 +1,20 @@
"""Util functions for tests"""
import os
import pytest
@pytest.fixture
def gateway_url():
    """Return the Gateway base URL taken from ``INVARIANT_GATEWAY_API_URL``.

    Returns:
        The value of the ``INVARIANT_GATEWAY_API_URL`` environment variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    url = os.environ.get("INVARIANT_GATEWAY_API_URL")
    # An empty value (e.g. a blank entry in an env file) would otherwise be
    # returned as-is and only fail later, when a scheme-less request URL is
    # built from it. Fail early with the same message as the unset case.
    if not url:
        raise ValueError("Please set the INVARIANT_GATEWAY_API_URL environment variable")
    return url
@pytest.fixture
def explorer_api_url():
    """Return the Explorer API base URL taken from ``INVARIANT_API_URL``.

    Returns:
        The value of the ``INVARIANT_API_URL`` environment variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    url = os.environ.get("INVARIANT_API_URL")
    # An empty value (e.g. a blank entry in an env file) would otherwise be
    # returned as-is and only fail later, when a scheme-less request URL is
    # built from it. Fail early with the same message as the unset case.
    if not url:
        raise ValueError("Please set the INVARIANT_API_URL environment variable")
    return url
@@ -9,14 +9,14 @@ import uuid
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import requests
from google import genai
from google.genai import types
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
def set_light_values(brightness: int, color_temp: str) -> dict[str, int | str]:
"""Set the brightness and color temperature of a room light. (mock API).
@@ -49,32 +49,35 @@ SET_LIGHT_VALUES_TOOL_CALL = {
}
async def _verify_trace_from_explorer(
context, explorer_api_url, dataset_name, expected_final_assistant_message
def _verify_trace_from_explorer(
explorer_api_url, dataset_name, expected_final_assistant_message
) -> None:
# Fetch the trace ids for the dataset.
# There will be 2 traces - the first will contain the system instruction, user prompt
# and the assistant tool call.
# The second will contain the system instruction, user prompt, the assistant tool call,
# the tool response and the assistant response.
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 2
trace_id_1 = traces[0]["id"]
trace_id_2 = traces[1]["id"]
# Fetch the trace
trace_response_1 = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id_1}"
trace_response_1 = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id_1}",
timeout=5,
)
trace_1 = await trace_response_1.json()
trace_1 = trace_response_1.json()
trace_response_2 = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id_2}"
trace_response_2 = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id_2}",
timeout=5,
)
trace_2 = await trace_response_2.json()
trace_2 = trace_response_2.json()
# Verify the trace messages
assert trace_1["messages"] == [
@@ -133,13 +136,13 @@ async def _verify_trace_from_explorer(
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_generate_content_with_tool_call(
context, explorer_api_url, gateway_url, push_to_explorer, do_stream
explorer_api_url, gateway_url, push_to_explorer, do_stream
):
"""
Test the generate content gateway calls with tool calling and response processing
without streaming.
"""
dataset_name = "test-dataset-gemini-tool-call-" + str(uuid.uuid4())
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
@@ -190,6 +193,6 @@ async def test_generate_content_with_tool_call(
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
await _verify_trace_from_explorer(
context, explorer_api_url, dataset_name, expected_final_assistant_message
_verify_trace_from_explorer(
explorer_api_url, dataset_name, expected_final_assistant_message
)
@@ -12,10 +12,9 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import PIL.Image
import requests
from google import genai
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -26,10 +25,10 @@ pytest_plugins = ("pytest_asyncio",)
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_generate_content(
context, explorer_api_url, gateway_url, do_stream, push_to_explorer
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test the generate content gateway calls without tool calling."""
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
@@ -78,18 +77,19 @@ async def test_generate_content(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await trace_response.json()
trace = trace_response.json()
# Verify the trace messages
assert trace["messages"] == [
@@ -111,10 +111,10 @@ async def test_generate_content(
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [True, False])
async def test_generate_content_with_image(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test that generate content gateway calls work with image."""
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
@@ -147,18 +147,19 @@ async def test_generate_content_with_image(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}", timeout=5
)
trace = await trace_response.json()
trace = trace_response.json()
# Verify the trace messages
assert len(trace["messages"]) == 2
assert trace["messages"][0]["role"] == "user"
@@ -175,10 +176,10 @@ async def test_generate_content_with_image(
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
async def test_generate_content_with_invariant_key_in_gemini_key_header(
context, explorer_api_url, gateway_url
explorer_api_url, gateway_url
):
"""Test the generate content gateway calls with the Invariant API Key in the Gemini Key header."""
dataset_name = "test-dataset-gemini-" + str(uuid.uuid4())
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
gemini_api_key = os.getenv("GEMINI_API_KEY")
with patch.dict(
os.environ,
@@ -208,18 +209,20 @@ async def test_generate_content_with_invariant_key_in_gemini_key_header(
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
# Verify the trace messages
assert trace["messages"] == [
@@ -10,11 +10,10 @@ import uuid
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import requests
from httpx import Client
from openai import OpenAI
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -22,13 +21,13 @@ pytest_plugins = ("pytest_asyncio",)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_chat_completion_with_tool_call_without_streaming(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""
Test the chat completions gateway calls with tool calling and response processing
without streaming.
"""
dataset_name = "test-dataset-open-ai-tool-call-" + str(uuid.uuid4())
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
@@ -106,18 +105,20 @@ async def test_chat_completion_with_tool_call_without_streaming(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
for message in trace["messages"]:
message.pop("annotations", None)
@@ -138,13 +139,13 @@ async def test_chat_completion_with_tool_call_without_streaming(
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [False, True])
async def test_chat_completion_with_tool_call_with_streaming(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""
Test the chat completions gateway calls with tool calling and response processing
while streaming.
"""
dataset_name = "test-dataset-open-ai-tool-call-" + str(uuid.uuid4())
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
@@ -229,18 +230,20 @@ async def test_chat_completion_with_tool_call_with_streaming(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
# Verify the trace messages
expected_messages = history + [final_response]
@@ -12,11 +12,10 @@ from unittest.mock import patch
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import requests
from httpx import Client
from openai import NotFoundError, OpenAI
from util import * # Needed for pytest fixtures
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -27,10 +26,10 @@ pytest_plugins = ("pytest_asyncio",)
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_chat_completion(
context, explorer_api_url, gateway_url, do_stream, push_to_explorer
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test the chat completions gateway calls without tool calling."""
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
@@ -66,18 +65,20 @@ async def test_chat_completion(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
for message in trace["messages"]:
message.pop("annotations", None)
@@ -98,10 +99,10 @@ async def test_chat_completion(
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize("push_to_explorer", [True, False])
async def test_chat_completion_with_image(
context, explorer_api_url, gateway_url, push_to_explorer
explorer_api_url, gateway_url, push_to_explorer
):
"""Test the chat completions gateway works with image."""
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
@@ -149,18 +150,20 @@ async def test_chat_completion_with_image(
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
for message in trace["messages"]:
message.pop("annotations", None)
@@ -181,10 +184,10 @@ async def test_chat_completion_with_image(
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
async def test_chat_completion_with_invariant_key_in_openai_key_header(
context, explorer_api_url, gateway_url
explorer_api_url, gateway_url
):
"""Test the chat completions gateway calls with the Invariant API Key in the OpenAI Key header."""
dataset_name = "test-dataset-open-ai-" + str(uuid.uuid4())
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
openai_api_key = os.getenv("OPENAI_API_KEY")
with patch.dict(
os.environ,
@@ -210,18 +213,20 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = await context.request.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = await traces_response.json()
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = await context.request.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}"
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = await trace_response.json()
trace = trace_response.json()
for message in trace["messages"]:
message.pop("annotations", None)
@@ -243,14 +248,14 @@ async def test_chat_completion_with_invariant_key_in_openai_key_header(
@pytest.mark.parametrize("do_stream", [True, False])
async def test_chat_completion_with_openai_exception(gateway_url, do_stream):
"""Test the chat completions gateway call when OpenAI API fails."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": "Bearer <some-key>"
}, # This key is not used for local tests
),
base_url=f"{gateway_url}/api/v1/gateway/{"test-dataset-open-ai-" + str(uuid.uuid4())}/openai",
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai",
)
with pytest.raises(Exception) as exc_info:
-1
View File
@@ -4,5 +4,4 @@ openai
pillow
pytest
pytest-asyncio
pytest-playwright
tavily-python
-45
View File
@@ -1,45 +0,0 @@
"""Util functions for tests"""
import os
import pytest
from playwright.async_api import async_playwright
@pytest.fixture
def gateway_url():
    """Return the Gateway base URL taken from ``INVARIANT_GATEWAY_API_URL``.

    Returns:
        The value of the ``INVARIANT_GATEWAY_API_URL`` environment variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    url = os.environ.get("INVARIANT_GATEWAY_API_URL")
    # Treat an empty value like a missing one: returning "" would only
    # surface later as a scheme-less request URL.
    if not url:
        raise ValueError("Please set the INVARIANT_GATEWAY_API_URL environment variable")
    return url
@pytest.fixture
def explorer_api_url():
    """Return the Explorer API base URL taken from ``INVARIANT_API_URL``.

    Returns:
        The value of the ``INVARIANT_API_URL`` environment variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    url = os.environ.get("INVARIANT_API_URL")
    # Treat an empty value like a missing one: returning "" would only
    # surface later as a scheme-less request URL.
    if not url:
        raise ValueError("Please set the INVARIANT_API_URL environment variable")
    return url
@pytest.fixture
async def playwright(scope="session"):
    """Enter the ``async_playwright()`` context and yield the resulting instance.

    The context is exited once the generator is resumed after the yield.

    NOTE(review): ``scope`` is never read in the body. It was presumably
    meant to be the fixture scope, which would have to be given to the
    fixture decorator rather than to this function -- confirm the intent.
    """
    async with async_playwright() as pw:
        yield pw
@pytest.fixture
async def browser(playwright, scope="session"):
    """Launch a headless Firefox browser, yield it, then close it.

    Args:
        playwright: Object exposing ``firefox.launch(...)``; the value
            produced by the ``playwright`` fixture.
        scope: Never read in the body.
            NOTE(review): presumably intended as the fixture scope -- confirm.
    """
    launched = await playwright.firefox.launch(headless=True)
    yield launched
    # Runs only when the generator is resumed after the yield.
    await launched.close()
@pytest.fixture
async def context(browser):
    """Create a browser context that ignores HTTPS errors, yield it, then close it.

    Args:
        browser: Object exposing ``new_context(...)``; the value produced by
            the ``browser`` fixture.
    """
    ctx = await browser.new_context(ignore_https_errors=True)
    yield ctx
    # Runs only when the generator is resumed after the yield.
    await ctx.close()