diff --git a/agentic_security/llm_providers/litellm_provider.py b/agentic_security/llm_providers/litellm_provider.py index 0cf9ef3..6dfab92 100644 --- a/agentic_security/llm_providers/litellm_provider.py +++ b/agentic_security/llm_providers/litellm_provider.py @@ -2,6 +2,8 @@ from typing import Any +import litellm + from agentic_security.llm_providers.base import ( BaseLLMProvider, LLMMessage, @@ -82,8 +84,6 @@ class LiteLLMProvider(BaseLLMProvider): return await self.chat(messages, **kwargs) async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: - import litellm - try: response = await litellm.acompletion( messages=self._messages_to_dicts(messages), @@ -101,8 +101,6 @@ class LiteLLMProvider(BaseLLMProvider): return self.sync_chat(messages, **kwargs) def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: - import litellm - try: response = litellm.completion( messages=self._messages_to_dicts(messages), diff --git a/tests/unit/llm_providers/test_litellm_provider.py b/tests/unit/llm_providers/test_litellm_provider.py new file mode 100644 index 0000000..bc3f1df --- /dev/null +++ b/tests/unit/llm_providers/test_litellm_provider.py @@ -0,0 +1,233 @@ +"""Tests for LiteLLM provider.""" + +import pytest +from inline_snapshot import snapshot +from unittest.mock import MagicMock, AsyncMock, patch + +from agentic_security.llm_providers.litellm_provider import LiteLLMProvider +from agentic_security.llm_providers.base import ( + LLMMessage, + LLMProviderError, + LLMRateLimitError, +) + + +def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop", + prompt_tokens=10, completion_tokens=5, total_tokens=15): + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].message.content = content + resp.choices[0].finish_reason = finish_reason + resp.model = model + resp.usage.prompt_tokens = prompt_tokens + resp.usage.completion_tokens = completion_tokens + resp.usage.total_tokens = total_tokens + return resp + + +class TestLiteLLMProviderInit: + def test_default_model(self): + provider = LiteLLMProvider() + assert provider.model == snapshot("openai/gpt-4o-mini") + + def test_custom_model(self): + provider = LiteLLMProvider(model="anthropic/claude-sonnet-4-6") + assert provider.model == snapshot("anthropic/claude-sonnet-4-6") + + def test_no_api_key_required(self): + provider = LiteLLMProvider() + assert provider._api_key is None + + def test_api_key_stored(self): + provider = LiteLLMProvider(api_key="sk-test") + assert provider._api_key == snapshot("sk-test") + + def test_api_base_stored(self): + provider = LiteLLMProvider(api_base="http://localhost:4000") + assert provider._api_base == snapshot("http://localhost:4000") + + +class TestLiteLLMProviderCallKwargs: + def test_drop_params_always_true(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert kwargs["drop_params"] is True + + def test_api_key_forwarded_when_set(self): + provider = LiteLLMProvider(api_key="sk-test") + kwargs = provider._call_kwargs() + assert kwargs["api_key"] == snapshot("sk-test") + + def test_api_key_omitted_when_none(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert "api_key" not in kwargs + + def test_api_base_forwarded_when_set(self): + provider = LiteLLMProvider(api_base="http://localhost:4000") + kwargs = provider._call_kwargs() + assert kwargs["api_base"] == snapshot("http://localhost:4000") + + def test_api_base_omitted_when_none(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert "api_base" not in kwargs + + def test_model_in_kwargs(self): + provider = LiteLLMProvider(model="groq/llama-3.3-70b-versatile") + kwargs = provider._call_kwargs() + assert kwargs["model"] == snapshot("groq/llama-3.3-70b-versatile") + + +class TestLiteLLMProviderMethods: + def test_get_supported_models(self): + models = LiteLLMProvider.get_supported_models() + assert "openai/gpt-4o" in models + assert "anthropic/claude-sonnet-4-6" in models + + def test_messages_to_dicts(self): + provider = LiteLLMProvider() + messages = [ + LLMMessage(role="system", content="Be helpful"), + LLMMessage(role="user", content="Hello"), + ] + result = provider._messages_to_dicts(messages) + assert result == snapshot( + [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + ] + ) + + def test_parse_response(self): + provider = LiteLLMProvider() + resp = _mock_response(content="Hi!", model="openai/gpt-4o") + result = provider._parse_response(resp) + assert result.content == snapshot("Hi!") + assert result.model == snapshot("openai/gpt-4o") + assert result.finish_reason == snapshot("stop") + assert result.usage == snapshot( + {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + + def test_parse_response_null_content(self): + provider = LiteLLMProvider() + resp = _mock_response(content=None) + result = provider._parse_response(resp) + assert result.content == snapshot("") + + def test_parse_response_no_usage(self): + provider = LiteLLMProvider() + resp = _mock_response() + resp.usage = None + result = provider._parse_response(resp) + assert result.usage is None + + +class TestLiteLLMProviderSync: + @pytest.fixture + def provider(self): + return LiteLLMProvider(model="openai/gpt-4o-mini") + + def test_sync_generate(self, provider): + resp = _mock_response(content="Sync response") + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp: + result = provider.sync_generate("Hello") + assert result.content == snapshot("Sync response") + call_kwargs = mock_comp.call_args.kwargs + assert call_kwargs["drop_params"] is True + + def test_sync_chat(self, provider): + resp = _mock_response(content="Chat response") + messages = [LLMMessage(role="user", content="Hi")] + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp): + result = provider.sync_chat(messages) + assert result.content == snapshot("Chat response") + + def test_sync_generate_with_system_prompt(self, provider): + resp = _mock_response(content="With system") + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp: + result = provider.sync_generate("Hello", system_prompt="Be brief") + assert result.content == snapshot("With system") + messages = mock_comp.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be brief" + + +class TestLiteLLMProviderAsync: + @pytest.fixture + def provider(self): + return LiteLLMProvider(model="anthropic/claude-sonnet-4-6") + + @pytest.mark.asyncio + async def test_generate(self, provider): + resp = _mock_response(content="Async response") + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp): + result = await provider.generate("Hello") + assert result.content == snapshot("Async response") + + @pytest.mark.asyncio + async def test_chat(self, provider): + resp = _mock_response(content="Async chat") + messages = [LLMMessage(role="user", content="Hi")] + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp) as mock_acomp: + result = await provider.chat(messages) + assert result.content == snapshot("Async chat") + call_kwargs = mock_acomp.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-sonnet-4-6" + assert call_kwargs["drop_params"] is True + + +class TestLiteLLMProviderErrors: + @pytest.fixture + def provider(self): + return LiteLLMProvider() + + def test_rate_limit_maps_to_llm_rate_limit_error(self, provider): + fake_exc = type("RateLimitError", (Exception,), {})() + fake_exc.__class__.__module__ = "litellm.exceptions" + fake_exc.__class__.__qualname__ = "RateLimitError" + with pytest.raises(LLMRateLimitError): + provider._handle_error(fake_exc) + + def test_generic_error_maps_to_llm_provider_error(self, provider): + with pytest.raises(LLMProviderError): + provider._handle_error(Exception("something went wrong")) + + def test_sync_chat_auth_error_raises_provider_error(self, provider): + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", side_effect=Exception("AuthenticationError: Invalid API key")): + with pytest.raises(LLMProviderError, match="Invalid API key"): + provider.sync_generate("test") + + @pytest.mark.asyncio + async def test_async_chat_timeout_raises_provider_error(self, provider): + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, + side_effect=Exception("Timeout: Request timed out")): + with pytest.raises(LLMProviderError, match="timed out"): + await provider.generate("test") + + @pytest.mark.asyncio + async def test_async_chat_model_not_found_raises_provider_error(self, provider): + provider = LiteLLMProvider(model="bad/nonexistent-model") + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, + side_effect=Exception("NotFoundError: Model not found")): + with pytest.raises(LLMProviderError, match="not found"): + await provider.generate("test") + + +class TestLiteLLMProviderFactory: + def test_factory_creates_litellm_provider(self): + from agentic_security.llm_providers.factory import create_provider + provider = create_provider("litellm") + assert isinstance(provider, LiteLLMProvider) + assert provider.model == snapshot("openai/gpt-4o-mini") + + def test_factory_creates_with_custom_model(self): + from agentic_security.llm_providers.factory import create_provider + provider = create_provider("litellm", model="groq/llama-3.3-70b-versatile") + assert provider.model == snapshot("groq/llama-3.3-70b-versatile") + + def test_factory_lists_litellm(self): + from agentic_security.llm_providers.factory import list_providers + providers = list_providers() + assert "litellm" in providers