Fix broken integration tests.

This commit is contained in:
Hemang
2025-05-06 20:35:35 +05:30
committed by Hemang Sarkar
parent dc9ac9c3c6
commit aec7808e3e
9 changed files with 52 additions and 40 deletions
+3 -2
View File
@@ -26,10 +26,11 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest
pip install .
- name: Run unit tests
run: ./run.sh unit-tests -s -vv
run: |
pip install -r tests/unit_tests/requirements.txt
./run.sh unit-tests -s -vv
continue-on-error: true
- name: Run integration tests
+1 -4
View File
@@ -14,10 +14,7 @@ from gateway.common.config_manager import (
GatewayConfigManager,
extract_guardrails_from_header,
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
from gateway.common.request_context import RequestContext
from gateway.converters.anthropic_to_invariant import (
@@ -12,11 +12,10 @@ from typing import Dict, List
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_anthropic_client
import anthropic
import pytest
import requests
from utils import get_anthropic_client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -249,6 +248,7 @@ async def test_streaming_response_with_tool_call(
messages = [{"role": "user", "content": query}]
response = weather_agent.get_streaming_response(messages)
assert response is not None
assert response[0][0].type == "text"
assert response[0][1].type == "tool_use"
@@ -303,7 +303,7 @@ async def test_response_with_tool_call_with_image(
"""Test the chat completion with image for the weather agent."""
weather_agent = WeatherAgent(gateway_url, push_to_explorer)
image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpeg"
image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpg"
with image_path.open("rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
@@ -367,4 +367,3 @@ async def test_response_with_tool_call_with_image(
].lower()
)
assert trace_messages[3]["role"] == "tool"
assert trace_messages[4]["role"] == "assistant"
@@ -2,17 +2,16 @@
import os
import sys
import uuid
import time
import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
from anthropic import APIStatusError, BadRequestError
from utils import add_guardrail_to_dataset, create_dataset, get_anthropic_client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -164,7 +163,7 @@ async def test_tool_call_guardrail_from_file(
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.messages.create(**request, stream=False)
_ = client.messages.create(**request, stream=False)
assert exc_info.value.status_code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
@@ -174,10 +173,10 @@ async def test_tool_call_guardrail_from_file(
else:
with pytest.raises(APIStatusError) as exc_info:
chat_response = client.messages.create(**request, stream=True)
for _ in chat_response:
response = client.messages.create(**request, stream=True)
for _ in response:
pass
assert (
"[Invariant] The response did not pass the guardrails"
in exc_info.value.message
@@ -535,10 +534,12 @@ async def test_preguardrailing_with_guardrails_from_explorer(
else:
if do_stream:
_ = client.messages.create(
response = client.messages.create(
**request,
stream=True,
)
for _ in response:
pass
else:
_ = client.messages.create(
**request,
@@ -2,17 +2,16 @@
import os
import sys
import uuid
import time
import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
from openai import BadRequestError, APIError
from openai import APIError, BadRequestError
from utils import add_guardrail_to_dataset, create_dataset, get_open_ai_client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -532,10 +531,12 @@ async def test_preguardrailing_with_guardrails_from_explorer(
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
_ = client.chat.completions.create(
response = client.chat.completions.create(
**request,
stream=True,
)
for _ in response:
pass
else:
_ = client.chat.completions.create(
**request,
@@ -1,10 +1,12 @@
import pytest
import uuid
from litellm import completion
import litellm
import time
import requests
"""Test the chat completions gateway calls with tool calling through litellm."""
import os
import time
import uuid
import pytest
import requests
from litellm import completion
MODEL_API_KEYS = {
"openai/gpt-4o": "OPENAI_API_KEY",
@@ -12,13 +14,21 @@ MODEL_API_KEYS = {
"anthropic/claude-3-5-haiku-20241022": "ANTHROPIC_API_KEY",
}
@pytest.mark.parametrize("litellm_model", MODEL_API_KEYS.keys(),)
@pytest.mark.parametrize(
"litellm_model",
MODEL_API_KEYS.keys(),
)
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(False, False)],
)
async def test_chat_completion(
explorer_api_url: str, litellm_model: str, gateway_url: str, do_stream: bool, push_to_explorer: bool
explorer_api_url: str,
litellm_model: str,
gateway_url: str,
do_stream: bool,
push_to_explorer: bool,
):
"""Test the chat completions gateway calls with tool calling through litellm."""
# Check if the API key is set in the environment variables
@@ -28,21 +38,23 @@ async def test_chat_completion(
if not api_key:
pytest.skip(f"Skipping {litellm_model} because {api_key_env_var} is not set")
dataset_name = f"test-dataset-litellm-{litellm_model}-{uuid.uuid4()}"
base_url = (f"{gateway_url}/api/v1/gateway/{dataset_name}"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway")
base_url = (
f"{gateway_url}/api/v1/gateway/{dataset_name}"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway"
)
base_url += "/" + litellm_model.split("/")[0] #add provider name
base_url += "/" + litellm_model.split("/")[0] # add provider name
if litellm_model.split("/")[0] == "gemini":
base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" #gemini expects the model name in the url
base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" # gemini expects the model name in the url
print(f"base_url: {base_url}")
chat_response = completion(
model=litellm_model,
messages=[{"role": "user", "content": "What is the capital of France?"}],
extra_headers= {"Invariant-Authorization": "Bearer <some-key>"},
extra_headers={
"Invariant-Authorization": f"Bearer {os.environ['INVARIANT_API_KEY']}"
},
stream=do_stream,
base_url=base_url,
)
@@ -92,4 +104,4 @@ async def test_chat_completion(
"role": "assistant",
"content": expected_assistant_message,
},
]
]
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

+1
View File
@@ -0,0 +1 @@
fastapi