test: add 29 unit tests and remove lazy imports

This commit is contained in:
RheagalFire
2026-05-19 01:50:40 +05:30
parent 6e6fdbcf28
commit a4833908ef
2 changed files with 235 additions and 4 deletions
@@ -2,6 +2,8 @@
from typing import Any
import litellm
from agentic_security.llm_providers.base import (
BaseLLMProvider,
LLMMessage,
@@ -82,8 +84,6 @@ class LiteLLMProvider(BaseLLMProvider):
return await self.chat(messages, **kwargs)
async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
import litellm
try:
response = await litellm.acompletion(
messages=self._messages_to_dicts(messages),
@@ -101,8 +101,6 @@ class LiteLLMProvider(BaseLLMProvider):
return self.sync_chat(messages, **kwargs)
def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
import litellm
try:
response = litellm.completion(
messages=self._messages_to_dicts(messages),
@@ -0,0 +1,233 @@
"""Tests for LiteLLM provider."""
import pytest
from inline_snapshot import snapshot
from unittest.mock import MagicMock, AsyncMock, patch
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)
def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop",
prompt_tokens=10, completion_tokens=5, total_tokens=15):
resp = MagicMock()
resp.choices = [MagicMock()]
resp.choices[0].message.content = content
resp.choices[0].finish_reason = finish_reason
resp.model = model
resp.usage.prompt_tokens = prompt_tokens
resp.usage.completion_tokens = completion_tokens
resp.usage.total_tokens = total_tokens
return resp
class TestLiteLLMProviderInit:
def test_default_model(self):
provider = LiteLLMProvider()
assert provider.model == snapshot("openai/gpt-4o-mini")
def test_custom_model(self):
provider = LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
assert provider.model == snapshot("anthropic/claude-sonnet-4-6")
def test_no_api_key_required(self):
provider = LiteLLMProvider()
assert provider._api_key is None
def test_api_key_stored(self):
provider = LiteLLMProvider(api_key="sk-test")
assert provider._api_key == snapshot("sk-test")
def test_api_base_stored(self):
provider = LiteLLMProvider(api_base="http://localhost:4000")
assert provider._api_base == snapshot("http://localhost:4000")
class TestLiteLLMProviderCallKwargs:
def test_drop_params_always_true(self):
provider = LiteLLMProvider()
kwargs = provider._call_kwargs()
assert kwargs["drop_params"] is True
def test_api_key_forwarded_when_set(self):
provider = LiteLLMProvider(api_key="sk-test")
kwargs = provider._call_kwargs()
assert kwargs["api_key"] == snapshot("sk-test")
def test_api_key_omitted_when_none(self):
provider = LiteLLMProvider()
kwargs = provider._call_kwargs()
assert "api_key" not in kwargs
def test_api_base_forwarded_when_set(self):
provider = LiteLLMProvider(api_base="http://localhost:4000")
kwargs = provider._call_kwargs()
assert kwargs["api_base"] == snapshot("http://localhost:4000")
def test_api_base_omitted_when_none(self):
provider = LiteLLMProvider()
kwargs = provider._call_kwargs()
assert "api_base" not in kwargs
def test_model_in_kwargs(self):
provider = LiteLLMProvider(model="groq/llama-3.3-70b-versatile")
kwargs = provider._call_kwargs()
assert kwargs["model"] == snapshot("groq/llama-3.3-70b-versatile")
class TestLiteLLMProviderMethods:
def test_get_supported_models(self):
models = LiteLLMProvider.get_supported_models()
assert "openai/gpt-4o" in models
assert "anthropic/claude-sonnet-4-6" in models
def test_messages_to_dicts(self):
provider = LiteLLMProvider()
messages = [
LLMMessage(role="system", content="Be helpful"),
LLMMessage(role="user", content="Hello"),
]
result = provider._messages_to_dicts(messages)
assert result == snapshot(
[
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "Hello"},
]
)
def test_parse_response(self):
provider = LiteLLMProvider()
resp = _mock_response(content="Hi!", model="openai/gpt-4o")
result = provider._parse_response(resp)
assert result.content == snapshot("Hi!")
assert result.model == snapshot("openai/gpt-4o")
assert result.finish_reason == snapshot("stop")
assert result.usage == snapshot(
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
)
def test_parse_response_null_content(self):
provider = LiteLLMProvider()
resp = _mock_response(content=None)
result = provider._parse_response(resp)
assert result.content == snapshot("")
def test_parse_response_no_usage(self):
provider = LiteLLMProvider()
resp = _mock_response()
resp.usage = None
result = provider._parse_response(resp)
assert result.usage is None
class TestLiteLLMProviderSync:
@pytest.fixture
def provider(self):
return LiteLLMProvider(model="openai/gpt-4o-mini")
def test_sync_generate(self, provider):
resp = _mock_response(content="Sync response")
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
result = provider.sync_generate("Hello")
assert result.content == snapshot("Sync response")
call_kwargs = mock_comp.call_args.kwargs
assert call_kwargs["drop_params"] is True
def test_sync_chat(self, provider):
resp = _mock_response(content="Chat response")
messages = [LLMMessage(role="user", content="Hi")]
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp):
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
def test_sync_generate_with_system_prompt(self, provider):
resp = _mock_response(content="With system")
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
result = provider.sync_generate("Hello", system_prompt="Be brief")
assert result.content == snapshot("With system")
messages = mock_comp.call_args.kwargs["messages"]
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "Be brief"
class TestLiteLLMProviderAsync:
@pytest.fixture
def provider(self):
return LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
@pytest.mark.asyncio
async def test_generate(self, provider):
resp = _mock_response(content="Async response")
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp):
result = await provider.generate("Hello")
assert result.content == snapshot("Async response")
@pytest.mark.asyncio
async def test_chat(self, provider):
resp = _mock_response(content="Async chat")
messages = [LLMMessage(role="user", content="Hi")]
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp) as mock_acomp:
result = await provider.chat(messages)
assert result.content == snapshot("Async chat")
call_kwargs = mock_acomp.call_args.kwargs
assert call_kwargs["model"] == "anthropic/claude-sonnet-4-6"
assert call_kwargs["drop_params"] is True
class TestLiteLLMProviderErrors:
@pytest.fixture
def provider(self):
return LiteLLMProvider()
def test_rate_limit_maps_to_llm_rate_limit_error(self, provider):
fake_exc = type("RateLimitError", (Exception,), {})()
fake_exc.__class__.__module__ = "litellm.exceptions"
fake_exc.__class__.__qualname__ = "RateLimitError"
with pytest.raises(LLMRateLimitError):
provider._handle_error(fake_exc)
def test_generic_error_maps_to_llm_provider_error(self, provider):
with pytest.raises(LLMProviderError):
provider._handle_error(Exception("something went wrong"))
def test_sync_chat_auth_error_raises_provider_error(self, provider):
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", side_effect=Exception("AuthenticationError: Invalid API key")):
with pytest.raises(LLMProviderError, match="Invalid API key"):
provider.sync_generate("test")
@pytest.mark.asyncio
async def test_async_chat_timeout_raises_provider_error(self, provider):
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
side_effect=Exception("Timeout: Request timed out")):
with pytest.raises(LLMProviderError, match="timed out"):
await provider.generate("test")
@pytest.mark.asyncio
async def test_async_chat_model_not_found_raises_provider_error(self, provider):
provider = LiteLLMProvider(model="bad/nonexistent-model")
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
side_effect=Exception("NotFoundError: Model not found")):
with pytest.raises(LLMProviderError, match="not found"):
await provider.generate("test")
class TestLiteLLMProviderFactory:
def test_factory_creates_litellm_provider(self):
from agentic_security.llm_providers.factory import create_provider
provider = create_provider("litellm")
assert isinstance(provider, LiteLLMProvider)
assert provider.model == snapshot("openai/gpt-4o-mini")
def test_factory_creates_with_custom_model(self):
from agentic_security.llm_providers.factory import create_provider
provider = create_provider("litellm", model="groq/llama-3.3-70b-versatile")
assert provider.model == snapshot("groq/llama-3.3-70b-versatile")
def test_factory_lists_litellm(self):
from agentic_security.llm_providers.factory import list_providers
providers = list_providers()
assert "litellm" in providers