diff --git a/.github/workflows/tests_ci.yml b/.github/workflows/tests_ci.yml index 3fe236b..96a2955 100644 --- a/.github/workflows/tests_ci.yml +++ b/.github/workflows/tests_ci.yml @@ -26,10 +26,11 @@ jobs: run: | python -m pip install --upgrade pip pip install pytest - pip install . - name: Run unit tests - run: ./run.sh unit-tests -s -vv + run: | + pip install -r tests/unit_tests/requirements.txt + ./run.sh unit-tests -s -vv continue-on-error: true - name: Run integration tests diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index e294a24..bbc2af7 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -14,10 +14,7 @@ from gateway.common.config_manager import ( GatewayConfigManager, extract_guardrails_from_header, ) -from gateway.common.constants import ( - CLIENT_TIMEOUT, - IGNORED_HEADERS, -) +from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.request_context import RequestContext from gateway.converters.anthropic_to_invariant import ( diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index abc3937..5e42881 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -12,11 +12,10 @@ from typing import Dict, List # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_anthropic_client - import anthropic import pytest import requests +from utils import get_anthropic_client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -249,6 +248,7 @@ async def test_streaming_response_with_tool_call( messages = [{"role": "user", "content": query}] response = weather_agent.get_streaming_response(messages) + assert response is not None assert response[0][0].type == "text" assert response[0][1].type == "tool_use" @@ -303,7 +303,7 @@ async def test_response_with_tool_call_with_image( """Test the chat completion with image for the weather agent.""" weather_agent = WeatherAgent(gateway_url, push_to_explorer) - image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpeg" + image_path = Path(__file__).parent.parent / "resources" / "images" / "new-york.jpg" with image_path.open("rb") as image_file: base64_image = base64.b64encode(image_file.read()).decode("utf-8") @@ -367,4 +367,3 @@ async def test_response_with_tool_call_with_image( ].lower() ) assert trace_messages[3]["role"] == "tool" - assert trace_messages[4]["role"] == "assistant" diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index 14f06e5..77bcf1f 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -2,17 +2,16 @@ import os import sys -import uuid import time +import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_anthropic_client, create_dataset, add_guardrail_to_dataset - import pytest import requests from anthropic import APIStatusError, BadRequestError +from utils import add_guardrail_to_dataset, create_dataset, get_anthropic_client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -164,7 +163,7 @@ async def test_tool_call_guardrail_from_file( if not do_stream: with pytest.raises(BadRequestError) as exc_info: - chat_response = client.messages.create(**request, stream=False) + _ = client.messages.create(**request, stream=False) assert exc_info.value.status_code == 400 assert "[Invariant] The response did not pass the guardrails" in str( @@ -174,10 +173,10 @@ async def test_tool_call_guardrail_from_file( else: with pytest.raises(APIStatusError) as exc_info: - chat_response = client.messages.create(**request, stream=True) - - for _ in chat_response: + response = client.messages.create(**request, stream=True) + for _ in response: pass + assert ( "[Invariant] The response did not pass the guardrails" in exc_info.value.message @@ -535,10 +534,12 @@ async def test_preguardrailing_with_guardrails_from_explorer( else: if do_stream: - _ = client.messages.create( + response = client.messages.create( **request, stream=True, ) + for _ in response: + pass else: _ = client.messages.create( **request, diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index a418f25..402b8a1 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -2,17 +2,16 @@ import os import sys -import uuid import time +import uuid # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_open_ai_client, create_dataset, add_guardrail_to_dataset - import pytest import requests -from openai import BadRequestError, APIError +from openai import APIError, BadRequestError +from utils import add_guardrail_to_dataset, create_dataset, get_open_ai_client # Pytest plugins pytest_plugins = ("pytest_asyncio",) @@ -532,10 +531,12 @@ async def test_preguardrailing_with_guardrails_from_explorer( assert "pun detected in user message" in str(exc_info.value) else: if do_stream: - _ = client.chat.completions.create( + response = client.chat.completions.create( **request, stream=True, ) + for _ in response: + pass else: _ = client.chat.completions.create( **request, diff --git a/tests/integration/litellm/test_chat_without_tool_call.py b/tests/integration/litellm/test_chat_without_tool_call.py index 879cd73..ee831a9 100644 --- a/tests/integration/litellm/test_chat_without_tool_call.py +++ b/tests/integration/litellm/test_chat_without_tool_call.py @@ -1,10 +1,12 @@ -import pytest -import uuid -from litellm import completion -import litellm -import time -import requests +"""Test the chat completions gateway calls with tool calling through litellm.""" + import os +import time +import uuid + +import pytest +import requests +from litellm import completion MODEL_API_KEYS = { "openai/gpt-4o": "OPENAI_API_KEY", @@ -12,13 +14,21 @@ MODEL_API_KEYS = { "anthropic/claude-3-5-haiku-20241022": "ANTHROPIC_API_KEY", } -@pytest.mark.parametrize("litellm_model", MODEL_API_KEYS.keys(),) + +@pytest.mark.parametrize( + "litellm_model", + MODEL_API_KEYS.keys(), +) @pytest.mark.parametrize( "do_stream, push_to_explorer", [(False, False)], ) async def test_chat_completion( - explorer_api_url: str, litellm_model: str, gateway_url: str, do_stream: bool, push_to_explorer: bool + explorer_api_url: str, + litellm_model: str, + gateway_url: str, + do_stream: bool, + push_to_explorer: bool, ): """Test the chat completions gateway calls with tool calling through litellm.""" # Check if the API key is set in the environment variables @@ -28,21 +38,23 @@ async def test_chat_completion( if not api_key: pytest.skip(f"Skipping {litellm_model} because {api_key_env_var} is not set") - dataset_name = f"test-dataset-litellm-{litellm_model}-{uuid.uuid4()}" - base_url = (f"{gateway_url}/api/v1/gateway/{dataset_name}" - if push_to_explorer - else f"{gateway_url}/api/v1/gateway") + base_url = ( + f"{gateway_url}/api/v1/gateway/{dataset_name}" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway" + ) - base_url += "/" + litellm_model.split("/")[0] #add provider name + base_url += "/" + litellm_model.split("/")[0] # add provider name if litellm_model.split("/")[0] == "gemini": - base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" #gemini expects the model name in the url + base_url += f"/v1beta/models/{litellm_model.split('/')[1]}" # gemini expects the model name in the url - print(f"base_url: {base_url}") chat_response = completion( model=litellm_model, messages=[{"role": "user", "content": "What is the capital of France?"}], - extra_headers= {"Invariant-Authorization": "Bearer "}, + extra_headers={ + "Invariant-Authorization": f"Bearer {os.environ['INVARIANT_API_KEY']}" + }, stream=do_stream, base_url=base_url, ) @@ -92,4 +104,4 @@ async def test_chat_completion( "role": "assistant", "content": expected_assistant_message, }, - ] \ No newline at end of file + ] diff --git a/tests/integration/resources/images/new-york.jpeg b/tests/integration/resources/images/new-york.jpeg deleted file mode 100644 index 80bc4a6..0000000 Binary files a/tests/integration/resources/images/new-york.jpeg and /dev/null differ diff --git a/tests/integration/resources/images/new-york.jpg b/tests/integration/resources/images/new-york.jpg new file mode 100644 index 0000000..9ddb87b Binary files /dev/null and b/tests/integration/resources/images/new-york.jpg differ diff --git a/tests/unit_tests/requirements.txt b/tests/unit_tests/requirements.txt new file mode 100644 index 0000000..170703d --- /dev/null +++ b/tests/unit_tests/requirements.txt @@ -0,0 +1 @@ +fastapi \ No newline at end of file