mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-23 21:59:57 +02:00
Merge pull request #299 from RheagalFire/feat/add-litellm-provider
feat: add LiteLLM as provider for 100+ LLM backends
This commit is contained in:
@@ -7,6 +7,7 @@ from agentic_security.llm_providers.base import (
|
||||
)
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
|
||||
from agentic_security.llm_providers.factory import create_provider, get_provider_class
|
||||
|
||||
__all__ = [
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"LLMRateLimitError",
|
||||
"OpenAIProvider",
|
||||
"AnthropicProvider",
|
||||
"LiteLLMProvider",
|
||||
"create_provider",
|
||||
"get_provider_class",
|
||||
]
|
||||
|
||||
@@ -14,9 +14,11 @@ def _ensure_registered() -> None:
|
||||
return
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
_PROVIDERS["openai"] = OpenAIProvider
|
||||
_PROVIDERS["anthropic"] = AnthropicProvider
|
||||
_PROVIDERS["litellm"] = LiteLLMProvider
|
||||
|
||||
|
||||
def register_provider(name: str, provider_class: type[BaseLLMProvider]) -> None:
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
"""LiteLLM provider — unified access to 100+ LLM backends."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class LiteLLMProvider(BaseLLMProvider):
    """LLM provider using LiteLLM SDK for 100+ backends.

    Accepts any LiteLLM model string (e.g. ``openai/gpt-4o``,
    ``anthropic/claude-sonnet-4-6``, ``groq/llama-3.3-70b-versatile``).

    Every request is sent with ``drop_params=True``; per-call keyword
    arguments are merged over the provider defaults and may override them.
    Any failure raised while calling LiteLLM or parsing its response is
    re-raised as ``LLMProviderError`` (or ``LLMRateLimitError`` for
    LiteLLM rate-limit errors), chained to the original exception.
    """

    DEFAULT_MODEL = "openai/gpt-4o-mini"

    # Fully-qualified name of LiteLLM's rate-limit exception. Errors are
    # matched by name (walking the exception's MRO) rather than by
    # ``isinstance`` so look-alike classes used in tests are still recognised.
    _RATE_LIMIT_QUALNAME = "litellm.exceptions.RateLimitError"

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: str | None = None,
        api_base: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the model and optional credentials.

        Args:
            model: LiteLLM model string; defaults to ``DEFAULT_MODEL``.
            api_key: Optional API key forwarded to LiteLLM on each call.
            api_base: Optional base URL forwarded to LiteLLM on each call.
            **kwargs: Passed through to ``BaseLLMProvider.__init__``.
        """
        super().__init__(model, **kwargs)
        self._api_key = api_key
        self._api_base = api_base

    def _call_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments shared by every LiteLLM call.

        ``api_key`` and ``api_base`` are only included when they are set to
        a truthy value, so unset credentials are never sent as ``None``.
        """
        kwargs: dict[str, Any] = {"model": self.model, "drop_params": True}
        if self._api_key:
            kwargs["api_key"] = self._api_key
        if self._api_base:
            kwargs["api_base"] = self._api_base
        return kwargs

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """Return a sample of model strings usable with this provider."""
        return [
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            "anthropic/claude-sonnet-4-6",
            "anthropic/claude-haiku-4-5",
            "groq/llama-3.3-70b-versatile",
            "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
        ]

    @staticmethod
    def _build_messages(prompt: str, kwargs: dict[str, Any]) -> list[LLMMessage]:
        """Build the message list for a single-prompt request.

        Pops ``system_prompt`` out of ``kwargs`` (mutating it) so the key is
        not forwarded to LiteLLM; a truthy value becomes a leading system
        message.
        """
        messages = [LLMMessage(role="user", content=prompt)]
        if system_prompt := kwargs.pop("system_prompt", None):
            messages.insert(0, LLMMessage(role="system", content=system_prompt))
        return messages

    def _messages_to_dicts(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert ``LLMMessage`` objects into role/content dictionaries."""
        return [{"role": m.role, "content": m.content} for m in messages]

    def _parse_response(self, response: Any) -> LLMResponse:
        """Convert a LiteLLM completion response into an ``LLMResponse``.

        Only the first choice is used. A ``None`` message content becomes an
        empty string, and ``usage`` is ``None`` when the response carries no
        usage data.

        Raises:
            LLMProviderError: If the response contains no choices.
        """
        choices = response.choices
        if not choices:
            # Fail with a descriptive error instead of a bare IndexError.
            raise LLMProviderError("LiteLLM response contained no choices")
        choice = choices[0]

        usage = None
        raw_usage = getattr(response, "usage", None)
        if raw_usage:
            usage = {
                "prompt_tokens": raw_usage.prompt_tokens,
                "completion_tokens": raw_usage.completion_tokens,
                "total_tokens": raw_usage.total_tokens,
            }
        return LLMResponse(
            content=choice.message.content or "",
            model=getattr(response, "model", self.model),
            finish_reason=choice.finish_reason,
            usage=usage,
        )

    def _handle_error(self, e: Exception) -> None:
        """Translate ``e`` into a provider error; this method always raises.

        Raises:
            LLMRateLimitError: If ``e`` is, or derives from, LiteLLM's
                ``RateLimitError``.
            LLMProviderError: For every other exception.
        """
        # Already one of our own errors: re-raise untouched rather than
        # wrapping it a second time.
        if isinstance(e, (LLMProviderError, LLMRateLimitError)):
            raise e
        # Walk the MRO so subclasses of RateLimitError are recognised too.
        for klass in type(e).__mro__:
            if f"{klass.__module__}.{klass.__name__}" == self._RATE_LIMIT_QUALNAME:
                raise LLMRateLimitError(str(e)) from e
        raise LLMProviderError(str(e)) from e

    async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Asynchronously complete a single user prompt.

        Args:
            prompt: The user message content.
            **kwargs: Forwarded to :meth:`chat`; an optional ``system_prompt``
                is consumed here and sent as a leading system message.
        """
        messages = self._build_messages(prompt, kwargs)
        return await self.chat(messages, **kwargs)

    async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Asynchronously send ``messages`` through ``litellm.acompletion``.

        Args:
            messages: Conversation to send.
            **kwargs: Extra LiteLLM arguments; they override the provider
                defaults from :meth:`_call_kwargs` on key collision.

        Raises:
            LLMRateLimitError: On a LiteLLM rate-limit error.
            LLMProviderError: On any other failure.
        """
        try:
            response = await litellm.acompletion(
                messages=self._messages_to_dicts(messages),
                **{**self._call_kwargs(), **kwargs},
            )
            return self._parse_response(response)
        except Exception as e:
            self._handle_error(e)
            # Unreachable: _handle_error always raises. Kept so the function
            # has no implicit ``None`` return path.
            raise

    def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Synchronously complete a single user prompt.

        Args:
            prompt: The user message content.
            **kwargs: Forwarded to :meth:`sync_chat`; an optional
                ``system_prompt`` is consumed here and sent as a leading
                system message.
        """
        messages = self._build_messages(prompt, kwargs)
        return self.sync_chat(messages, **kwargs)

    def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Synchronously send ``messages`` through ``litellm.completion``.

        Args:
            messages: Conversation to send.
            **kwargs: Extra LiteLLM arguments; they override the provider
                defaults from :meth:`_call_kwargs` on key collision.

        Raises:
            LLMRateLimitError: On a LiteLLM rate-limit error.
            LLMProviderError: On any other failure.
        """
        try:
            response = litellm.completion(
                messages=self._messages_to_dicts(messages),
                **{**self._call_kwargs(), **kwargs},
            )
            return self._parse_response(response)
        except Exception as e:
            self._handle_error(e)
            # Unreachable: _handle_error always raises. Kept so the function
            # has no implicit ``None`` return path.
            raise
|
||||
@@ -0,0 +1,233 @@
|
||||
"""Tests for LiteLLM provider."""
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
|
||||
from agentic_security.llm_providers.base import (
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop",
|
||||
prompt_tokens=10, completion_tokens=5, total_tokens=15):
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = content
|
||||
resp.choices[0].finish_reason = finish_reason
|
||||
resp.model = model
|
||||
resp.usage.prompt_tokens = prompt_tokens
|
||||
resp.usage.completion_tokens = completion_tokens
|
||||
resp.usage.total_tokens = total_tokens
|
||||
return resp
|
||||
|
||||
|
||||
class TestLiteLLMProviderInit:
    """Constructor behaviour: model selection and credential storage."""

    def test_default_model(self):
        """Without arguments the provider falls back to the default model."""
        assert LiteLLMProvider().model == snapshot("openai/gpt-4o-mini")

    def test_custom_model(self):
        """An explicit model string is stored unchanged."""
        configured = LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
        assert configured.model == snapshot("anthropic/claude-sonnet-4-6")

    def test_no_api_key_required(self):
        """Construction succeeds without a key and stores ``None``."""
        assert LiteLLMProvider()._api_key is None

    def test_api_key_stored(self):
        """A supplied API key is kept on the instance."""
        keyed = LiteLLMProvider(api_key="sk-test")
        assert keyed._api_key == snapshot("sk-test")

    def test_api_base_stored(self):
        """A supplied API base URL is kept on the instance."""
        proxied = LiteLLMProvider(api_base="http://localhost:4000")
        assert proxied._api_base == snapshot("http://localhost:4000")
|
||||
|
||||
|
||||
class TestLiteLLMProviderCallKwargs:
    """Contents of the keyword arguments shared by every LiteLLM call."""

    def test_drop_params_always_true(self):
        """``drop_params`` is set to ``True`` for a default provider."""
        shared = LiteLLMProvider()._call_kwargs()
        assert shared["drop_params"] is True

    def test_api_key_forwarded_when_set(self):
        """A configured API key appears in the call kwargs."""
        shared = LiteLLMProvider(api_key="sk-test")._call_kwargs()
        assert shared["api_key"] == snapshot("sk-test")

    def test_api_key_omitted_when_none(self):
        """No ``api_key`` entry is emitted when none was configured."""
        shared = LiteLLMProvider()._call_kwargs()
        assert "api_key" not in shared

    def test_api_base_forwarded_when_set(self):
        """A configured API base appears in the call kwargs."""
        shared = LiteLLMProvider(api_base="http://localhost:4000")._call_kwargs()
        assert shared["api_base"] == snapshot("http://localhost:4000")

    def test_api_base_omitted_when_none(self):
        """No ``api_base`` entry is emitted when none was configured."""
        shared = LiteLLMProvider()._call_kwargs()
        assert "api_base" not in shared

    def test_model_in_kwargs(self):
        """The provider's model string is always part of the call kwargs."""
        shared = LiteLLMProvider(model="groq/llama-3.3-70b-versatile")._call_kwargs()
        assert shared["model"] == snapshot("groq/llama-3.3-70b-versatile")
|
||||
|
||||
|
||||
class TestLiteLLMProviderMethods:
    """Helper methods: model listing, message conversion, response parsing."""

    def test_get_supported_models(self):
        """The advertised model list includes the expected sample entries."""
        advertised = LiteLLMProvider.get_supported_models()
        assert "openai/gpt-4o" in advertised
        assert "anthropic/claude-sonnet-4-6" in advertised

    def test_messages_to_dicts(self):
        """Messages become role/content dicts in their original order."""
        conversation = [
            LLMMessage(role="system", content="Be helpful"),
            LLMMessage(role="user", content="Hello"),
        ]
        converted = LiteLLMProvider()._messages_to_dicts(conversation)
        assert converted == snapshot(
            [
                {"role": "system", "content": "Be helpful"},
                {"role": "user", "content": "Hello"},
            ]
        )

    def test_parse_response(self):
        """Content, model, finish reason and usage are all carried over."""
        raw = _mock_response(content="Hi!", model="openai/gpt-4o")
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.content == snapshot("Hi!")
        assert parsed.model == snapshot("openai/gpt-4o")
        assert parsed.finish_reason == snapshot("stop")
        assert parsed.usage == snapshot(
            {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        )

    def test_parse_response_null_content(self):
        """A ``None`` message content is normalised to an empty string."""
        raw = _mock_response(content=None)
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.content == snapshot("")

    def test_parse_response_no_usage(self):
        """A response without usage data yields ``usage is None``."""
        raw = _mock_response()
        raw.usage = None
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.usage is None
|
||||
|
||||
|
||||
class TestLiteLLMProviderSync:
    """Synchronous entry points, with ``litellm.completion`` patched out."""

    # Patch target for the synchronous LiteLLM call used by the provider.
    _COMPLETION = "agentic_security.llm_providers.litellm_provider.litellm.completion"

    @pytest.fixture
    def provider(self):
        """Provider under test, pinned to an explicit model."""
        return LiteLLMProvider(model="openai/gpt-4o-mini")

    def test_sync_generate(self, provider):
        """``sync_generate`` returns parsed content and sends ``drop_params``."""
        canned = _mock_response(content="Sync response")
        with patch(self._COMPLETION, return_value=canned) as mock_comp:
            outcome = provider.sync_generate("Hello")
        assert outcome.content == snapshot("Sync response")
        sent = mock_comp.call_args.kwargs
        assert sent["drop_params"] is True

    def test_sync_chat(self, provider):
        """``sync_chat`` returns the parsed content of the mocked response."""
        canned = _mock_response(content="Chat response")
        conversation = [LLMMessage(role="user", content="Hi")]
        with patch(self._COMPLETION, return_value=canned):
            outcome = provider.sync_chat(conversation)
        assert outcome.content == snapshot("Chat response")

    def test_sync_generate_with_system_prompt(self, provider):
        """A ``system_prompt`` is sent as the first message of the request."""
        canned = _mock_response(content="With system")
        with patch(self._COMPLETION, return_value=canned) as mock_comp:
            outcome = provider.sync_generate("Hello", system_prompt="Be brief")
        assert outcome.content == snapshot("With system")
        leading = mock_comp.call_args.kwargs["messages"][0]
        assert leading["role"] == "system"
        assert leading["content"] == "Be brief"
|
||||
|
||||
|
||||
class TestLiteLLMProviderAsync:
    """Asynchronous entry points, with ``litellm.acompletion`` patched out."""

    # Patch target for the asynchronous LiteLLM call used by the provider.
    _ACOMPLETION = "agentic_security.llm_providers.litellm_provider.litellm.acompletion"

    @pytest.fixture
    def provider(self):
        """Provider under test, pinned to an explicit model."""
        return LiteLLMProvider(model="anthropic/claude-sonnet-4-6")

    @pytest.mark.asyncio
    async def test_generate(self, provider):
        """``generate`` returns the parsed content of the mocked response."""
        canned = _mock_response(content="Async response")
        with patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=canned):
            outcome = await provider.generate("Hello")
        assert outcome.content == snapshot("Async response")

    @pytest.mark.asyncio
    async def test_chat(self, provider):
        """``chat`` forwards the model and ``drop_params`` to LiteLLM."""
        canned = _mock_response(content="Async chat")
        conversation = [LLMMessage(role="user", content="Hi")]
        with patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=canned) as mock_acomp:
            outcome = await provider.chat(conversation)
        assert outcome.content == snapshot("Async chat")
        sent = mock_acomp.call_args.kwargs
        assert sent["model"] == "anthropic/claude-sonnet-4-6"
        assert sent["drop_params"] is True
|
||||
|
||||
|
||||
class TestLiteLLMProviderErrors:
    """Translation of LiteLLM failures into the provider's error types."""

    # Patch targets for the two LiteLLM calls used by the provider.
    _COMPLETION = "agentic_security.llm_providers.litellm_provider.litellm.completion"
    _ACOMPLETION = "agentic_security.llm_providers.litellm_provider.litellm.acompletion"

    @pytest.fixture
    def provider(self):
        """Provider under test with default settings."""
        return LiteLLMProvider()

    def test_rate_limit_maps_to_llm_rate_limit_error(self, provider):
        """An exception named like LiteLLM's RateLimitError is mapped."""
        # Build a stand-in class whose module/name mimic the real exception.
        lookalike = type("RateLimitError", (Exception,), {})
        lookalike.__module__ = "litellm.exceptions"
        lookalike.__qualname__ = "RateLimitError"
        with pytest.raises(LLMRateLimitError):
            provider._handle_error(lookalike())

    def test_generic_error_maps_to_llm_provider_error(self, provider):
        """Any other exception becomes ``LLMProviderError``."""
        plain = Exception("something went wrong")
        with pytest.raises(LLMProviderError):
            provider._handle_error(plain)

    def test_sync_chat_auth_error_raises_provider_error(self, provider):
        """A failure inside the sync call surfaces with its message intact."""
        failure = Exception("AuthenticationError: Invalid API key")
        with patch(self._COMPLETION, side_effect=failure):
            with pytest.raises(LLMProviderError, match="Invalid API key"):
                provider.sync_generate("test")

    @pytest.mark.asyncio
    async def test_async_chat_timeout_raises_provider_error(self, provider):
        """A failure inside the async call surfaces with its message intact."""
        failure = Exception("Timeout: Request timed out")
        with patch(self._ACOMPLETION, new_callable=AsyncMock, side_effect=failure):
            with pytest.raises(LLMProviderError, match="timed out"):
                await provider.generate("test")

    @pytest.mark.asyncio
    async def test_async_chat_model_not_found_raises_provider_error(self, provider):
        """An unknown-model failure is reported as ``LLMProviderError``."""
        # The fixture value is deliberately replaced by a bad-model provider.
        provider = LiteLLMProvider(model="bad/nonexistent-model")
        failure = Exception("NotFoundError: Model not found")
        with patch(self._ACOMPLETION, new_callable=AsyncMock, side_effect=failure):
            with pytest.raises(LLMProviderError, match="not found"):
                await provider.generate("test")
|
||||
|
||||
|
||||
class TestLiteLLMProviderFactory:
    """Registration of the LiteLLM provider with the provider factory."""

    def test_factory_creates_litellm_provider(self):
        """``create_provider("litellm")`` yields a default LiteLLMProvider."""
        from agentic_security.llm_providers.factory import create_provider

        built = create_provider("litellm")
        assert isinstance(built, LiteLLMProvider)
        assert built.model == snapshot("openai/gpt-4o-mini")

    def test_factory_creates_with_custom_model(self):
        """A ``model`` keyword passed to the factory reaches the provider."""
        from agentic_security.llm_providers.factory import create_provider

        built = create_provider("litellm", model="groq/llama-3.3-70b-versatile")
        assert built.model == snapshot("groq/llama-3.3-70b-versatile")

    def test_factory_lists_litellm(self):
        """The provider name shows up in the factory's provider listing."""
        from agentic_security.llm_providers.factory import list_providers

        assert "litellm" in list_providers()
|
||||
Reference in New Issue
Block a user