diff --git a/agentic_security/llm_providers/__init__.py b/agentic_security/llm_providers/__init__.py index dbe17d5..54663f7 100644 --- a/agentic_security/llm_providers/__init__.py +++ b/agentic_security/llm_providers/__init__.py @@ -7,6 +7,7 @@ from agentic_security.llm_providers.base import ( ) from agentic_security.llm_providers.openai_provider import OpenAIProvider from agentic_security.llm_providers.anthropic_provider import AnthropicProvider +from agentic_security.llm_providers.litellm_provider import LiteLLMProvider from agentic_security.llm_providers.factory import create_provider, get_provider_class __all__ = [ @@ -17,6 +18,7 @@ __all__ = [ "LLMRateLimitError", "OpenAIProvider", "AnthropicProvider", + "LiteLLMProvider", "create_provider", "get_provider_class", ] diff --git a/agentic_security/llm_providers/factory.py b/agentic_security/llm_providers/factory.py index 4736bef..577e5a1 100644 --- a/agentic_security/llm_providers/factory.py +++ b/agentic_security/llm_providers/factory.py @@ -14,9 +14,11 @@ def _ensure_registered() -> None: return from agentic_security.llm_providers.openai_provider import OpenAIProvider from agentic_security.llm_providers.anthropic_provider import AnthropicProvider + from agentic_security.llm_providers.litellm_provider import LiteLLMProvider _PROVIDERS["openai"] = OpenAIProvider _PROVIDERS["anthropic"] = AnthropicProvider + _PROVIDERS["litellm"] = LiteLLMProvider def register_provider(name: str, provider_class: type[BaseLLMProvider]) -> None: diff --git a/agentic_security/llm_providers/litellm_provider.py b/agentic_security/llm_providers/litellm_provider.py new file mode 100644 index 0000000..6dfab92 --- /dev/null +++ b/agentic_security/llm_providers/litellm_provider.py @@ -0,0 +1,112 @@ +"""LiteLLM provider — unified access to 100+ LLM backends.""" + +from typing import Any + +import litellm + +from agentic_security.llm_providers.base import ( + BaseLLMProvider, + LLMMessage, + LLMProviderError, + LLMRateLimitError, + LLMResponse, +) + + +class LiteLLMProvider(BaseLLMProvider): + """LLM provider using LiteLLM SDK for 100+ backends. + + Accepts any LiteLLM model string (e.g. ``openai/gpt-4o``, + ``anthropic/claude-sonnet-4-6``, ``groq/llama-3.3-70b-versatile``). + """ + + DEFAULT_MODEL = "openai/gpt-4o-mini" + + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + api_base: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(model, **kwargs) + self._api_key = api_key + self._api_base = api_base + + def _call_kwargs(self) -> dict[str, Any]: + kwargs: dict[str, Any] = {"model": self.model, "drop_params": True} + if self._api_key: + kwargs["api_key"] = self._api_key + if self._api_base: + kwargs["api_base"] = self._api_base + return kwargs + + @classmethod + def get_supported_models(cls) -> list[str]: + return [ + "openai/gpt-4o", + "openai/gpt-4o-mini", + "anthropic/claude-sonnet-4-6", + "anthropic/claude-haiku-4-5", + "groq/llama-3.3-70b-versatile", + "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo", + ] + + def _messages_to_dicts(self, messages: list[LLMMessage]) -> list[dict[str, str]]: + return [{"role": m.role, "content": m.content} for m in messages] + + def _parse_response(self, response: Any) -> LLMResponse: + choice = response.choices[0] + usage = None + if response.usage: + usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + return LLMResponse( + content=choice.message.content or "", + model=getattr(response, "model", self.model), + finish_reason=choice.finish_reason, + usage=usage, + ) + + def _handle_error(self, e: Exception) -> None: + qualname = f"{type(e).__module__}.{type(e).__name__}" + if qualname == "litellm.exceptions.RateLimitError": + raise LLMRateLimitError(str(e)) from e + raise LLMProviderError(str(e)) from e + + async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse: + messages = [LLMMessage(role="user", content=prompt)] + if system_prompt := kwargs.pop("system_prompt", None): + messages.insert(0, LLMMessage(role="system", content=system_prompt)) + return await self.chat(messages, **kwargs) + + async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: + try: + response = await litellm.acompletion( + messages=self._messages_to_dicts(messages), + **{**self._call_kwargs(), **kwargs}, + ) + return self._parse_response(response) + except Exception as e: + self._handle_error(e) + raise + + def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse: + messages = [LLMMessage(role="user", content=prompt)] + if system_prompt := kwargs.pop("system_prompt", None): + messages.insert(0, LLMMessage(role="system", content=system_prompt)) + return self.sync_chat(messages, **kwargs) + + def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: + try: + response = litellm.completion( + messages=self._messages_to_dicts(messages), + **{**self._call_kwargs(), **kwargs}, + ) + return self._parse_response(response) + except Exception as e: + self._handle_error(e) + raise diff --git a/tests/unit/llm_providers/test_litellm_provider.py b/tests/unit/llm_providers/test_litellm_provider.py new file mode 100644 index 0000000..bc3f1df --- /dev/null +++ b/tests/unit/llm_providers/test_litellm_provider.py @@ -0,0 +1,233 @@ +"""Tests for LiteLLM provider.""" + +import pytest +from inline_snapshot import snapshot +from unittest.mock import MagicMock, AsyncMock, patch + +from agentic_security.llm_providers.litellm_provider import LiteLLMProvider +from agentic_security.llm_providers.base import ( + LLMMessage, + LLMProviderError, + LLMRateLimitError, +) + + +def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop", + prompt_tokens=10, completion_tokens=5, total_tokens=15): + resp = MagicMock() + resp.choices = [MagicMock()] + resp.choices[0].message.content = content + resp.choices[0].finish_reason = finish_reason + resp.model = model + resp.usage.prompt_tokens = prompt_tokens + resp.usage.completion_tokens = completion_tokens + resp.usage.total_tokens = total_tokens + return resp + + +class TestLiteLLMProviderInit: + def test_default_model(self): + provider = LiteLLMProvider() + assert provider.model == snapshot("openai/gpt-4o-mini") + + def test_custom_model(self): + provider = LiteLLMProvider(model="anthropic/claude-sonnet-4-6") + assert provider.model == snapshot("anthropic/claude-sonnet-4-6") + + def test_no_api_key_required(self): + provider = LiteLLMProvider() + assert provider._api_key is None + + def test_api_key_stored(self): + provider = LiteLLMProvider(api_key="sk-test") + assert provider._api_key == snapshot("sk-test") + + def test_api_base_stored(self): + provider = LiteLLMProvider(api_base="http://localhost:4000") + assert provider._api_base == snapshot("http://localhost:4000") + + +class TestLiteLLMProviderCallKwargs: + def test_drop_params_always_true(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert kwargs["drop_params"] is True + + def test_api_key_forwarded_when_set(self): + provider = LiteLLMProvider(api_key="sk-test") + kwargs = provider._call_kwargs() + assert kwargs["api_key"] == snapshot("sk-test") + + def test_api_key_omitted_when_none(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert "api_key" not in kwargs + + def test_api_base_forwarded_when_set(self): + provider = LiteLLMProvider(api_base="http://localhost:4000") + kwargs = provider._call_kwargs() + assert kwargs["api_base"] == snapshot("http://localhost:4000") + + def test_api_base_omitted_when_none(self): + provider = LiteLLMProvider() + kwargs = provider._call_kwargs() + assert "api_base" not in kwargs + + def test_model_in_kwargs(self): + provider = LiteLLMProvider(model="groq/llama-3.3-70b-versatile") + kwargs = provider._call_kwargs() + assert kwargs["model"] == snapshot("groq/llama-3.3-70b-versatile") + + +class TestLiteLLMProviderMethods: + def test_get_supported_models(self): + models = LiteLLMProvider.get_supported_models() + assert "openai/gpt-4o" in models + assert "anthropic/claude-sonnet-4-6" in models + + def test_messages_to_dicts(self): + provider = LiteLLMProvider() + messages = [ + LLMMessage(role="system", content="Be helpful"), + LLMMessage(role="user", content="Hello"), + ] + result = provider._messages_to_dicts(messages) + assert result == snapshot( + [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + ] + ) + + def test_parse_response(self): + provider = LiteLLMProvider() + resp = _mock_response(content="Hi!", model="openai/gpt-4o") + result = provider._parse_response(resp) + assert result.content == snapshot("Hi!") + assert result.model == snapshot("openai/gpt-4o") + assert result.finish_reason == snapshot("stop") + assert result.usage == snapshot( + {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + + def test_parse_response_null_content(self): + provider = LiteLLMProvider() + resp = _mock_response(content=None) + result = provider._parse_response(resp) + assert result.content == snapshot("") + + def test_parse_response_no_usage(self): + provider = LiteLLMProvider() + resp = _mock_response() + resp.usage = None + result = provider._parse_response(resp) + assert result.usage is None + + +class TestLiteLLMProviderSync: + @pytest.fixture + def provider(self): + return LiteLLMProvider(model="openai/gpt-4o-mini") + + def test_sync_generate(self, provider): + resp = _mock_response(content="Sync response") + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp: + result = provider.sync_generate("Hello") + assert result.content == snapshot("Sync response") + call_kwargs = mock_comp.call_args.kwargs + assert call_kwargs["drop_params"] is True + + def test_sync_chat(self, provider): + resp = _mock_response(content="Chat response") + messages = [LLMMessage(role="user", content="Hi")] + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp): + result = provider.sync_chat(messages) + assert result.content == snapshot("Chat response") + + def test_sync_generate_with_system_prompt(self, provider): + resp = _mock_response(content="With system") + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp: + result = provider.sync_generate("Hello", system_prompt="Be brief") + assert result.content == snapshot("With system") + messages = mock_comp.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be brief" + + +class TestLiteLLMProviderAsync: + @pytest.fixture + def provider(self): + return LiteLLMProvider(model="anthropic/claude-sonnet-4-6") + + @pytest.mark.asyncio + async def test_generate(self, provider): + resp = _mock_response(content="Async response") + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp): + result = await provider.generate("Hello") + assert result.content == snapshot("Async response") + + @pytest.mark.asyncio + async def test_chat(self, provider): + resp = _mock_response(content="Async chat") + messages = [LLMMessage(role="user", content="Hi")] + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp) as mock_acomp: + result = await provider.chat(messages) + assert result.content == snapshot("Async chat") + call_kwargs = mock_acomp.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-sonnet-4-6" + assert call_kwargs["drop_params"] is True + + +class TestLiteLLMProviderErrors: + @pytest.fixture + def provider(self): + return LiteLLMProvider() + + def test_rate_limit_maps_to_llm_rate_limit_error(self, provider): + fake_exc = type("RateLimitError", (Exception,), {})() + fake_exc.__class__.__module__ = "litellm.exceptions" + fake_exc.__class__.__qualname__ = "RateLimitError" + with pytest.raises(LLMRateLimitError): + provider._handle_error(fake_exc) + + def test_generic_error_maps_to_llm_provider_error(self, provider): + with pytest.raises(LLMProviderError): + provider._handle_error(Exception("something went wrong")) + + def test_sync_chat_auth_error_raises_provider_error(self, provider): + with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", side_effect=Exception("AuthenticationError: Invalid API key")): + with pytest.raises(LLMProviderError, match="Invalid API key"): + provider.sync_generate("test") + + @pytest.mark.asyncio + async def test_async_chat_timeout_raises_provider_error(self, provider): + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, + side_effect=Exception("Timeout: Request timed out")): + with pytest.raises(LLMProviderError, match="timed out"): + await provider.generate("test") + + @pytest.mark.asyncio + async def test_async_chat_model_not_found_raises_provider_error(self, provider): + provider = LiteLLMProvider(model="bad/nonexistent-model") + with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, + side_effect=Exception("NotFoundError: Model not found")): + with pytest.raises(LLMProviderError, match="not found"): + await provider.generate("test") + + +class TestLiteLLMProviderFactory: + def test_factory_creates_litellm_provider(self): + from agentic_security.llm_providers.factory import create_provider + provider = create_provider("litellm") + assert isinstance(provider, LiteLLMProvider) + assert provider.model == snapshot("openai/gpt-4o-mini") + + def test_factory_creates_with_custom_model(self): + from agentic_security.llm_providers.factory import create_provider + provider = create_provider("litellm", model="groq/llama-3.3-70b-versatile") + assert provider.model == snapshot("groq/llama-3.3-70b-versatile") + + def test_factory_lists_litellm(self): + from agentic_security.llm_providers.factory import list_providers + providers = list_providers() + assert "litellm" in providers