mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-26 16:37:47 +02:00
Fix broken integration tests.
This commit is contained in:
@@ -26,10 +26,11 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install .
|
||||
|
||||
- name: Run unit tests
|
||||
run: ./run.sh unit-tests -s -vv
|
||||
run: |
|
||||
pip install -r tests/unit_tests/requirements.txt
|
||||
./run.sh unit-tests -s -vv
|
||||
continue-on-error: true
|
||||
|
||||
- name: Run integration tests
|
||||
|
||||
@@ -14,10 +14,7 @@ from gateway.common.config_manager import (
|
||||
GatewayConfigManager,
|
||||
extract_guardrails_from_header,
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.converters.anthropic_to_invariant import (
|
||||
|
||||
@@ -12,11 +12,10 @@ from typing import Dict, List
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_anthropic_client
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import requests
|
||||
from utils import get_anthropic_client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -249,6 +248,7 @@ async def test_streaming_response_with_tool_call(
|
||||
|
||||
messages = [{"role": "user", "content": query}]
|
||||
response = weather_agent.get_streaming_response(messages)
|
||||
|
||||
assert response is not None
|
||||
assert response[0][0].type == "text"
|
||||
assert response[0][1].type == "tool_use"
|
||||
@@ -303,7 +303,7 @@ async def test_response_with_tool_call_with_image(
|
||||
"""Test the chat completion with image for the weather agent."""
|
||||
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
|
||||
|
||||
image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpeg"
|
||||
image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpg"
|
||||
|
||||
with image_path.open("rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -367,4 +367,3 @@ async def test_response_with_tool_call_with_image(
|
||||
].lower()
|
||||
)
|
||||
assert trace_messages[3]["role"] == "tool"
|
||||
assert trace_messages[4]["role"] == "assistant"
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from anthropic import APIStatusError, BadRequestError
|
||||
from utils import add_guardrail_to_dataset, create_dataset, get_anthropic_client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -164,7 +163,7 @@ async def test_tool_call_guardrail_from_file(
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.messages.create(**request, stream=False)
|
||||
_ = client.messages.create(**request, stream=False)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
@@ -174,10 +173,10 @@ async def test_tool_call_guardrail_from_file(
|
||||
|
||||
else:
|
||||
with pytest.raises(APIStatusError) as exc_info:
|
||||
chat_response = client.messages.create(**request, stream=True)
|
||||
|
||||
for _ in chat_response:
|
||||
response = client.messages.create(**request, stream=True)
|
||||
for _ in response:
|
||||
pass
|
||||
|
||||
assert (
|
||||
"[Invariant] The response did not pass the guardrails"
|
||||
in exc_info.value.message
|
||||
@@ -535,10 +534,12 @@ async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
|
||||
else:
|
||||
if do_stream:
|
||||
_ = client.messages.create(
|
||||
response = client.messages.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
for _ in response:
|
||||
pass
|
||||
else:
|
||||
_ = client.messages.create(
|
||||
**request,
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from openai import BadRequestError, APIError
|
||||
from openai import APIError, BadRequestError
|
||||
from utils import add_guardrail_to_dataset, create_dataset, get_open_ai_client
|
||||
|
||||
# Pytest plugins
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
@@ -532,10 +531,12 @@ async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
assert "pun detected in user message" in str(exc_info.value)
|
||||
else:
|
||||
if do_stream:
|
||||
_ = client.chat.completions.create(
|
||||
response = client.chat.completions.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
for _ in response:
|
||||
pass
|
||||
else:
|
||||
_ = client.chat.completions.create(
|
||||
**request,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import pytest
|
||||
import uuid
|
||||
from litellm import completion
|
||||
import litellm
|
||||
import time
|
||||
import requests
|
||||
"""Test the chat completions gateway calls with tool calling through litellm."""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from litellm import completion
|
||||
|
||||
MODEL_API_KEYS = {
|
||||
"openai/gpt-4o": "OPENAI_API_KEY",
|
||||
@@ -12,13 +14,21 @@ MODEL_API_KEYS = {
|
||||
"anthropic/claude-3-5-haiku-20241022": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("litellm_model", MODEL_API_KEYS.keys(),)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"litellm_model",
|
||||
MODEL_API_KEYS.keys(),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, push_to_explorer",
|
||||
[(False, False)],
|
||||
)
|
||||
async def test_chat_completion(
|
||||
explorer_api_url: str, litellm_model: str, gateway_url: str, do_stream: bool, push_to_explorer: bool
|
||||
explorer_api_url: str,
|
||||
litellm_model: str,
|
||||
gateway_url: str,
|
||||
do_stream: bool,
|
||||
push_to_explorer: bool,
|
||||
):
|
||||
"""Test the chat completions gateway calls with tool calling through litellm."""
|
||||
# Check if the API key is set in the environment variables
|
||||
@@ -28,21 +38,23 @@ async def test_chat_completion(
|
||||
if not api_key:
|
||||
pytest.skip(f"Skipping {litellm_model} because {api_key_env_var} is not set")
|
||||
|
||||
|
||||
dataset_name = f"test-dataset-litellm-{litellm_model}-{uuid.uuid4()}"
|
||||
base_url = (f"{gateway_url}/api/v1/gateway/{dataset_name}"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway")
|
||||
base_url = (
|
||||
f"{gateway_url}/api/v1/gateway/{dataset_name}"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway"
|
||||
)
|
||||
|
||||
base_url += "/" + litellm_model.split("/")[0] #add provider name
|
||||
base_url += "/" + litellm_model.split("/")[0] # add provider name
|
||||
if litellm_model.split("/")[0] == "gemini":
|
||||
base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" #gemini expects the model name in the url
|
||||
base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" # gemini expects the model name in the url
|
||||
|
||||
print(f"base_url: {base_url}")
|
||||
chat_response = completion(
|
||||
model=litellm_model,
|
||||
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||
extra_headers= {"Invariant-Authorization": "Bearer <some-key>"},
|
||||
extra_headers={
|
||||
"Invariant-Authorization": f"Bearer {os.environ['INVARIANT_API_KEY']}"
|
||||
},
|
||||
stream=do_stream,
|
||||
base_url=base_url,
|
||||
)
|
||||
@@ -92,4 +104,4 @@ async def test_chat_completion(
|
||||
"role": "assistant",
|
||||
"content": expected_assistant_message,
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.4 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 96 KiB |
@@ -0,0 +1 @@
|
||||
fastapi
|
||||
Reference in New Issue
Block a user