mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-23 21:59:57 +02:00
feat: US-004 - Unified LLM Provider Abstraction
Create unified provider abstraction layer for direct LLM integrations beyond HTTP specs, inspired by FuzzyAI's comprehensive provider system. - Add BaseLLMProvider abstract class with standard interface (generate, chat, sync_generate, sync_chat methods) - Implement OpenAIProvider supporting chat completions API - Implement AnthropicProvider supporting messages API - Create provider factory for instantiation by name (create_provider, get_provider_class) - Add 60 unit tests covering all provider implementations
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMMessage,
|
||||
LLMResponse,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.factory import create_provider, get_provider_class
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMProvider",
|
||||
"LLMMessage",
|
||||
"LLMResponse",
|
||||
"LLMProviderError",
|
||||
"LLMRateLimitError",
|
||||
"OpenAIProvider",
|
||||
"AnthropicProvider",
|
||||
"create_provider",
|
||||
"get_provider_class",
|
||||
]
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic API provider supporting messages API."""
|
||||
|
||||
DEFAULT_MODEL = "claude-3-haiku-20240307"
|
||||
API_KEY_ENV = "ANTHROPIC_API_KEY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = DEFAULT_MODEL,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(model, **kwargs)
|
||||
self.api_key = api_key or os.environ.get(self.API_KEY_ENV)
|
||||
if not self.api_key:
|
||||
raise LLMProviderError(f"{self.API_KEY_ENV} not set")
|
||||
self.base_url = base_url
|
||||
self._client: Any = None
|
||||
self._async_client: Any = None
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
if self._client is None:
|
||||
import anthropic
|
||||
kwargs: dict[str, Any] = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
kwargs["base_url"] = self.base_url
|
||||
self._client = anthropic.Anthropic(**kwargs)
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
import anthropic
|
||||
kwargs: dict[str, Any] = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
kwargs["base_url"] = self.base_url
|
||||
self._async_client = anthropic.AsyncAnthropic(**kwargs)
|
||||
return self._async_client
|
||||
|
||||
@classmethod
|
||||
def get_supported_models(cls) -> list[str]:
|
||||
return [
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
"claude-3-5-sonnet-latest",
|
||||
]
|
||||
|
||||
def _messages_to_dicts(
|
||||
self, messages: list[LLMMessage]
|
||||
) -> tuple[str | None, list[dict[str, str]]]:
|
||||
"""Extract system prompt and convert messages to Anthropic format."""
|
||||
system_prompt = None
|
||||
chat_messages = []
|
||||
for m in messages:
|
||||
if m.role == "system":
|
||||
system_prompt = m.content
|
||||
else:
|
||||
chat_messages.append({"role": m.role, "content": m.content})
|
||||
return system_prompt, chat_messages
|
||||
|
||||
def _parse_response(self, response: Any) -> LLMResponse:
|
||||
content = ""
|
||||
if response.content:
|
||||
block = response.content[0]
|
||||
if hasattr(block, "text"):
|
||||
content = block.text
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = {
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
}
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model,
|
||||
finish_reason=response.stop_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_error(self, e: Exception) -> None:
|
||||
import anthropic
|
||||
if isinstance(e, anthropic.RateLimitError):
|
||||
raise LLMRateLimitError(str(e)) from e
|
||||
if isinstance(e, anthropic.APIError):
|
||||
raise LLMProviderError(str(e)) from e
|
||||
raise LLMProviderError(str(e)) from e
|
||||
|
||||
async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
messages = [LLMMessage(role="user", content=prompt)]
|
||||
if system_prompt := kwargs.pop("system_prompt", None):
|
||||
messages.insert(0, LLMMessage(role="system", content=system_prompt))
|
||||
return await self.chat(messages, **kwargs)
|
||||
|
||||
async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
client = self._get_async_client()
|
||||
system_prompt, chat_messages = self._messages_to_dicts(messages)
|
||||
create_kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": chat_messages,
|
||||
"max_tokens": kwargs.pop("max_tokens", 1024),
|
||||
}
|
||||
if system_prompt:
|
||||
create_kwargs["system"] = system_prompt
|
||||
create_kwargs.update(kwargs)
|
||||
try:
|
||||
response = await client.messages.create(**create_kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
self._handle_error(e)
|
||||
raise # unreachable, but satisfies type checker
|
||||
|
||||
def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
messages = [LLMMessage(role="user", content=prompt)]
|
||||
if system_prompt := kwargs.pop("system_prompt", None):
|
||||
messages.insert(0, LLMMessage(role="system", content=system_prompt))
|
||||
return self.sync_chat(messages, **kwargs)
|
||||
|
||||
def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
client = self._get_client()
|
||||
system_prompt, chat_messages = self._messages_to_dicts(messages)
|
||||
create_kwargs: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": chat_messages,
|
||||
"max_tokens": kwargs.pop("max_tokens", 1024),
|
||||
}
|
||||
if system_prompt:
|
||||
create_kwargs["system"] = system_prompt
|
||||
create_kwargs.update(kwargs)
|
||||
try:
|
||||
response = client.messages.create(**create_kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
self._handle_error(e)
|
||||
raise # unreachable, but satisfies type checker
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._async_client:
|
||||
await self._async_client.close()
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Base LLM provider abstraction for unified API access.
|
||||
|
||||
Inspired by FuzzyAI's provider architecture, providing a simple interface
|
||||
for both sync and async LLM interactions.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LLMProviderError(Exception):
|
||||
"""Base exception for LLM provider errors."""
|
||||
|
||||
|
||||
class LLMRateLimitError(LLMProviderError):
|
||||
"""Raised when rate limit is exceeded."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMMessage:
|
||||
"""A message in a chat conversation."""
|
||||
role: str # "system", "user", or "assistant"
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str
|
||||
model: str | None = None
|
||||
finish_reason: str | None = None
|
||||
usage: dict[str, int] | None = None
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
|
||||
"""Abstract base class for LLM providers.
|
||||
|
||||
Subclasses must implement generate() and chat() methods for both
|
||||
sync and async variants.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs: Any) -> None:
|
||||
self.model = model
|
||||
self._extra = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
"""Generate a response from a single prompt."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
"""Generate a response from a chat conversation."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
"""Synchronous version of generate()."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
"""Synchronous version of chat()."""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_supported_models(cls) -> list[str]:
|
||||
"""Return list of supported model names."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close any open connections. Override if cleanup is needed."""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(model={self.model!r})"
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Factory for creating LLM provider instances."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError
|
||||
|
||||
# Provider registry mapping name to class
|
||||
_PROVIDERS: dict[str, type[BaseLLMProvider]] = {}
|
||||
|
||||
|
||||
def _ensure_registered() -> None:
|
||||
"""Lazy registration of built-in providers."""
|
||||
if _PROVIDERS:
|
||||
return
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
_PROVIDERS["openai"] = OpenAIProvider
|
||||
_PROVIDERS["anthropic"] = AnthropicProvider
|
||||
|
||||
|
||||
def register_provider(name: str, provider_class: type[BaseLLMProvider]) -> None:
|
||||
"""Register a custom provider class."""
|
||||
_ensure_registered()
|
||||
_PROVIDERS[name.lower()] = provider_class
|
||||
|
||||
|
||||
def get_provider_class(name: str) -> type[BaseLLMProvider]:
|
||||
"""Get provider class by name."""
|
||||
_ensure_registered()
|
||||
name_lower = name.lower()
|
||||
if name_lower not in _PROVIDERS:
|
||||
available = ", ".join(sorted(_PROVIDERS.keys()))
|
||||
raise LLMProviderError(f"Unknown provider: {name}. Available: {available}")
|
||||
return _PROVIDERS[name_lower]
|
||||
|
||||
|
||||
def list_providers() -> list[str]:
|
||||
"""List all available provider names."""
|
||||
_ensure_registered()
|
||||
return sorted(_PROVIDERS.keys())
|
||||
|
||||
|
||||
def create_provider(
|
||||
name: str,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseLLMProvider:
|
||||
"""Create a provider instance by name.
|
||||
|
||||
Args:
|
||||
name: Provider name ("openai", "anthropic", etc.)
|
||||
model: Model name. If None, uses provider's default.
|
||||
**kwargs: Additional arguments passed to provider constructor.
|
||||
|
||||
Returns:
|
||||
Configured provider instance.
|
||||
|
||||
Raises:
|
||||
LLMProviderError: If provider name is unknown.
|
||||
"""
|
||||
provider_class = get_provider_class(name)
|
||||
if model is None:
|
||||
model = getattr(provider_class, "DEFAULT_MODEL", None)
|
||||
if model is None:
|
||||
raise LLMProviderError(f"No model specified and {name} has no default")
|
||||
return provider_class(model=model, **kwargs)
|
||||
@@ -0,0 +1,126 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseLLMProvider):
|
||||
"""OpenAI API provider supporting chat completions."""
|
||||
|
||||
DEFAULT_MODEL = "gpt-4o-mini"
|
||||
API_KEY_ENV = "OPENAI_API_KEY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = DEFAULT_MODEL,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(model, **kwargs)
|
||||
self.api_key = api_key or os.environ.get(self.API_KEY_ENV)
|
||||
if not self.api_key:
|
||||
raise LLMProviderError(f"{self.API_KEY_ENV} not set")
|
||||
self.base_url = base_url
|
||||
self._client: Any = None
|
||||
self._async_client: Any = None
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
if self._client is None:
|
||||
import openai
|
||||
self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
return self._client
|
||||
|
||||
def _get_async_client(self) -> Any:
|
||||
if self._async_client is None:
|
||||
import openai
|
||||
self._async_client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
return self._async_client
|
||||
|
||||
@classmethod
|
||||
def get_supported_models(cls) -> list[str]:
|
||||
return [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o3-mini",
|
||||
]
|
||||
|
||||
def _messages_to_dicts(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
|
||||
return [{"role": m.role, "content": m.content} for m in messages]
|
||||
|
||||
def _parse_response(self, response: Any) -> LLMResponse:
|
||||
choice = response.choices[0]
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
return LLMResponse(
|
||||
content=choice.message.content or "",
|
||||
model=response.model,
|
||||
finish_reason=choice.finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_error(self, e: Exception) -> None:
|
||||
import openai
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
raise LLMRateLimitError(str(e)) from e
|
||||
raise LLMProviderError(str(e)) from e
|
||||
|
||||
async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
messages = [LLMMessage(role="user", content=prompt)]
|
||||
if system_prompt := kwargs.pop("system_prompt", None):
|
||||
messages.insert(0, LLMMessage(role="system", content=system_prompt))
|
||||
return await self.chat(messages, **kwargs)
|
||||
|
||||
async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
client = self._get_async_client()
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=self._messages_to_dicts(messages),
|
||||
**kwargs,
|
||||
)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
self._handle_error(e)
|
||||
raise # unreachable, but satisfies type checker
|
||||
|
||||
def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
|
||||
messages = [LLMMessage(role="user", content=prompt)]
|
||||
if system_prompt := kwargs.pop("system_prompt", None):
|
||||
messages.insert(0, LLMMessage(role="system", content=system_prompt))
|
||||
return self.sync_chat(messages, **kwargs)
|
||||
|
||||
def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
|
||||
client = self._get_client()
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=self._messages_to_dicts(messages),
|
||||
**kwargs,
|
||||
)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
self._handle_error(e)
|
||||
raise # unreachable, but satisfies type checker
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._async_client:
|
||||
await self._async_client.close()
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Tests for Anthropic provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
|
||||
|
||||
class TestAnthropicProviderInit:
|
||||
def test_requires_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
with pytest.raises(LLMProviderError) as exc:
|
||||
AnthropicProvider()
|
||||
assert "ANTHROPIC_API_KEY" in str(exc.value)
|
||||
|
||||
def test_accepts_api_key_directly(self, monkeypatch):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
provider = AnthropicProvider(api_key="test-key")
|
||||
assert provider.api_key == snapshot("test-key")
|
||||
|
||||
def test_uses_env_api_key(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "env-key")
|
||||
provider = AnthropicProvider()
|
||||
assert provider.api_key == snapshot("env-key")
|
||||
|
||||
def test_default_model(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
provider = AnthropicProvider()
|
||||
assert provider.model == snapshot("claude-3-haiku-20240307")
|
||||
|
||||
def test_custom_model(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
provider = AnthropicProvider(model="claude-3-5-sonnet-latest")
|
||||
assert provider.model == snapshot("claude-3-5-sonnet-latest")
|
||||
|
||||
def test_custom_base_url(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
provider = AnthropicProvider(base_url="https://custom.api.com")
|
||||
assert provider.base_url == snapshot("https://custom.api.com")
|
||||
|
||||
|
||||
class TestAnthropicProviderMethods:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
return AnthropicProvider()
|
||||
|
||||
def test_get_supported_models(self, provider):
|
||||
models = provider.get_supported_models()
|
||||
assert "claude-3-haiku-20240307" in models
|
||||
assert "claude-3-5-sonnet-latest" in models
|
||||
|
||||
def test_messages_to_dicts_simple(self, provider):
|
||||
messages = [LLMMessage(role="user", content="Hello")]
|
||||
system, chat = provider._messages_to_dicts(messages)
|
||||
assert system is None
|
||||
assert chat == snapshot([{"role": "user", "content": "Hello"}])
|
||||
|
||||
def test_messages_to_dicts_with_system(self, provider):
|
||||
messages = [
|
||||
LLMMessage(role="system", content="Be helpful"),
|
||||
LLMMessage(role="user", content="Hello"),
|
||||
]
|
||||
system, chat = provider._messages_to_dicts(messages)
|
||||
assert system == snapshot("Be helpful")
|
||||
assert chat == snapshot([{"role": "user", "content": "Hello"}])
|
||||
|
||||
def test_messages_to_dicts_multi_turn(self, provider):
|
||||
messages = [
|
||||
LLMMessage(role="user", content="Hi"),
|
||||
LLMMessage(role="assistant", content="Hello!"),
|
||||
LLMMessage(role="user", content="How are you?"),
|
||||
]
|
||||
system, chat = provider._messages_to_dicts(messages)
|
||||
assert system is None
|
||||
assert chat == snapshot(
|
||||
[
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
)
|
||||
|
||||
def test_parse_response(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Hi there!"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage.input_tokens = 10
|
||||
mock_response.usage.output_tokens = 5
|
||||
|
||||
result = provider._parse_response(mock_response)
|
||||
assert result.content == snapshot("Hi there!")
|
||||
assert result.model == snapshot("claude-3-haiku-20240307")
|
||||
assert result.finish_reason == snapshot("end_turn")
|
||||
assert result.usage == snapshot({"input_tokens": 10, "output_tokens": 5})
|
||||
|
||||
def test_parse_response_empty_content(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = []
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
result = provider._parse_response(mock_response)
|
||||
assert result.content == snapshot("")
|
||||
|
||||
|
||||
class TestAnthropicProviderSync:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
return AnthropicProvider()
|
||||
|
||||
def test_sync_generate(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.messages.create.return_value = mock_response
|
||||
result = provider.sync_generate("Hello")
|
||||
assert result.content == snapshot("Response")
|
||||
|
||||
def test_sync_chat(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Chat response"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.messages.create.return_value = mock_response
|
||||
result = provider.sync_chat(messages)
|
||||
assert result.content == snapshot("Chat response")
|
||||
|
||||
|
||||
class TestAnthropicProviderAsync:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
return AnthropicProvider()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Async response"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.generate("Hello")
|
||||
assert result.content == snapshot("Async response")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Async chat"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.chat(messages)
|
||||
assert result.content == snapshot("Async chat")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_system_prompt(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "With system"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.stop_reason = "end_turn"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.messages.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.generate("Hello", system_prompt="Be brief")
|
||||
assert result.content == snapshot("With system")
|
||||
|
||||
|
||||
class TestAnthropicProviderErrors:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
return AnthropicProvider()
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import anthropic
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
|
||||
def test_handle_api_error(self, provider):
|
||||
import anthropic
|
||||
with pytest.raises(LLMProviderError):
|
||||
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
provider._handle_error(Exception("something went wrong"))
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Tests for base LLM provider classes."""
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.base import (
|
||||
BaseLLMProvider,
|
||||
LLMMessage,
|
||||
LLMProviderError,
|
||||
LLMRateLimitError,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestLLMMessage:
|
||||
def test_create_message(self):
|
||||
msg = LLMMessage(role="user", content="hello")
|
||||
assert msg.role == snapshot("user")
|
||||
assert msg.content == snapshot("hello")
|
||||
|
||||
def test_system_message(self):
|
||||
msg = LLMMessage(role="system", content="You are helpful")
|
||||
assert msg.role == snapshot("system")
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
def test_minimal_response(self):
|
||||
resp = LLMResponse(content="Hello!")
|
||||
assert resp.content == snapshot("Hello!")
|
||||
assert resp.model is None
|
||||
assert resp.finish_reason is None
|
||||
assert resp.usage is None
|
||||
|
||||
def test_full_response(self):
|
||||
resp = LLMResponse(
|
||||
content="Hi there",
|
||||
model="gpt-4o",
|
||||
finish_reason="stop",
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
assert resp.content == snapshot("Hi there")
|
||||
assert resp.model == snapshot("gpt-4o")
|
||||
assert resp.finish_reason == snapshot("stop")
|
||||
assert resp.usage == snapshot(
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||
)
|
||||
|
||||
|
||||
class TestExceptions:
|
||||
def test_provider_error_is_exception(self):
|
||||
with pytest.raises(LLMProviderError):
|
||||
raise LLMProviderError("test error")
|
||||
|
||||
def test_rate_limit_error_is_provider_error(self):
|
||||
with pytest.raises(LLMProviderError):
|
||||
raise LLMRateLimitError("rate limited")
|
||||
|
||||
def test_rate_limit_error_specific_catch(self):
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
raise LLMRateLimitError("rate limited")
|
||||
|
||||
|
||||
class TestBaseLLMProvider:
|
||||
def test_cannot_instantiate_directly(self):
|
||||
with pytest.raises(TypeError):
|
||||
BaseLLMProvider(model="test") # type: ignore
|
||||
|
||||
def test_repr_format(self):
|
||||
# Create a concrete implementation for testing
|
||||
class ConcreteProvider(BaseLLMProvider):
|
||||
async def generate(self, prompt, **kwargs):
|
||||
return LLMResponse(content="")
|
||||
|
||||
async def chat(self, messages, **kwargs):
|
||||
return LLMResponse(content="")
|
||||
|
||||
def sync_generate(self, prompt, **kwargs):
|
||||
return LLMResponse(content="")
|
||||
|
||||
def sync_chat(self, messages, **kwargs):
|
||||
return LLMResponse(content="")
|
||||
|
||||
@classmethod
|
||||
def get_supported_models(cls):
|
||||
return ["test-model"]
|
||||
|
||||
provider = ConcreteProvider(model="test-model")
|
||||
assert repr(provider) == snapshot("ConcreteProvider(model='test-model')")
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for LLM provider factory."""
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.factory import (
|
||||
create_provider,
|
||||
get_provider_class,
|
||||
list_providers,
|
||||
register_provider,
|
||||
)
|
||||
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
|
||||
|
||||
|
||||
class TestListProviders:
|
||||
def test_includes_builtin_providers(self):
|
||||
providers = list_providers()
|
||||
assert "openai" in providers
|
||||
assert "anthropic" in providers
|
||||
|
||||
def test_returns_sorted_list(self):
|
||||
providers = list_providers()
|
||||
assert providers == sorted(providers)
|
||||
|
||||
|
||||
class TestGetProviderClass:
|
||||
def test_get_openai(self):
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
cls = get_provider_class("openai")
|
||||
assert cls is OpenAIProvider
|
||||
|
||||
def test_get_anthropic(self):
|
||||
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
|
||||
cls = get_provider_class("anthropic")
|
||||
assert cls is AnthropicProvider
|
||||
|
||||
def test_case_insensitive(self):
|
||||
cls1 = get_provider_class("OpenAI")
|
||||
cls2 = get_provider_class("OPENAI")
|
||||
cls3 = get_provider_class("openai")
|
||||
assert cls1 is cls2 is cls3
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMProviderError) as exc:
|
||||
get_provider_class("unknown")
|
||||
assert "Unknown provider: unknown" in str(exc.value)
|
||||
assert "Available:" in str(exc.value)
|
||||
|
||||
|
||||
class TestRegisterProvider:
|
||||
def test_register_custom_provider(self):
|
||||
class CustomProvider(BaseLLMProvider):
|
||||
async def generate(self, prompt, **kwargs):
|
||||
return LLMResponse(content="custom")
|
||||
|
||||
async def chat(self, messages, **kwargs):
|
||||
return LLMResponse(content="custom")
|
||||
|
||||
def sync_generate(self, prompt, **kwargs):
|
||||
return LLMResponse(content="custom")
|
||||
|
||||
def sync_chat(self, messages, **kwargs):
|
||||
return LLMResponse(content="custom")
|
||||
|
||||
@classmethod
|
||||
def get_supported_models(cls):
|
||||
return ["custom-model"]
|
||||
|
||||
register_provider("custom", CustomProvider)
|
||||
assert "custom" in list_providers()
|
||||
assert get_provider_class("custom") is CustomProvider
|
||||
|
||||
|
||||
class TestCreateProvider:
|
||||
def test_create_openai_with_default_model(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = create_provider("openai")
|
||||
assert provider.model == snapshot("gpt-4o-mini")
|
||||
|
||||
def test_create_openai_with_custom_model(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = create_provider("openai", model="gpt-4o")
|
||||
assert provider.model == snapshot("gpt-4o")
|
||||
|
||||
def test_create_anthropic_with_default_model(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
provider = create_provider("anthropic")
|
||||
assert provider.model == snapshot("claude-3-haiku-20240307")
|
||||
|
||||
def test_create_anthropic_with_custom_model(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||
provider = create_provider("anthropic", model="claude-3-5-sonnet-latest")
|
||||
assert provider.model == snapshot("claude-3-5-sonnet-latest")
|
||||
|
||||
def test_create_with_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
provider = create_provider("openai", api_key="direct-key")
|
||||
assert provider.api_key == snapshot("direct-key")
|
||||
|
||||
def test_create_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMProviderError):
|
||||
create_provider("unknown")
|
||||
|
||||
def test_case_insensitive(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = create_provider("OpenAI")
|
||||
assert provider.__class__.__name__ == snapshot("OpenAIProvider")
|
||||
@@ -0,0 +1,204 @@
|
||||
"""Tests for OpenAI provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from agentic_security.llm_providers.openai_provider import OpenAIProvider
|
||||
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
|
||||
|
||||
|
||||
class TestOpenAIProviderInit:
|
||||
def test_requires_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(LLMProviderError) as exc:
|
||||
OpenAIProvider()
|
||||
assert "OPENAI_API_KEY" in str(exc.value)
|
||||
|
||||
def test_accepts_api_key_directly(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
provider = OpenAIProvider(api_key="test-key")
|
||||
assert provider.api_key == snapshot("test-key")
|
||||
|
||||
def test_uses_env_api_key(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "env-key")
|
||||
provider = OpenAIProvider()
|
||||
assert provider.api_key == snapshot("env-key")
|
||||
|
||||
def test_default_model(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = OpenAIProvider()
|
||||
assert provider.model == snapshot("gpt-4o-mini")
|
||||
|
||||
def test_custom_model(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = OpenAIProvider(model="gpt-4o")
|
||||
assert provider.model == snapshot("gpt-4o")
|
||||
|
||||
def test_custom_base_url(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
provider = OpenAIProvider(base_url="https://custom.api.com")
|
||||
assert provider.base_url == snapshot("https://custom.api.com")
|
||||
|
||||
|
||||
class TestOpenAIProviderMethods:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
return OpenAIProvider()
|
||||
|
||||
def test_get_supported_models(self, provider):
|
||||
models = provider.get_supported_models()
|
||||
assert "gpt-4o" in models
|
||||
assert "gpt-4o-mini" in models
|
||||
assert "gpt-3.5-turbo" in models
|
||||
|
||||
def test_messages_to_dicts(self, provider):
|
||||
messages = [
|
||||
LLMMessage(role="system", content="Be helpful"),
|
||||
LLMMessage(role="user", content="Hello"),
|
||||
]
|
||||
result = provider._messages_to_dicts(messages)
|
||||
assert result == snapshot(
|
||||
[
|
||||
{"role": "system", "content": "Be helpful"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
)
|
||||
|
||||
def test_parse_response(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Hi there!"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
mock_response.usage.total_tokens = 15
|
||||
|
||||
result = provider._parse_response(mock_response)
|
||||
assert result.content == snapshot("Hi there!")
|
||||
assert result.model == snapshot("gpt-4o")
|
||||
assert result.finish_reason == snapshot("stop")
|
||||
assert result.usage == snapshot(
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||||
)
|
||||
|
||||
def test_parse_response_empty_content(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = None
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = None
|
||||
|
||||
result = provider._parse_response(mock_response)
|
||||
assert result.content == snapshot("")
|
||||
|
||||
|
||||
class TestOpenAIProviderSync:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
return OpenAIProvider()
|
||||
|
||||
def test_sync_generate(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
result = provider.sync_generate("Hello")
|
||||
assert result.content == snapshot("Response")
|
||||
|
||||
def test_sync_chat(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Chat response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage = None
|
||||
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create.return_value = mock_response
|
||||
result = provider.sync_chat(messages)
|
||||
assert result.content == snapshot("Chat response")
|
||||
|
||||
|
||||
class TestOpenAIProviderAsync:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
return OpenAIProvider()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Async response"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.generate("Hello")
|
||||
assert result.content == snapshot("Async response")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Async chat"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage = None
|
||||
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.chat(messages)
|
||||
assert result.content == snapshot("Async chat")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_system_prompt(self, provider):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "With system"
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage = None
|
||||
|
||||
with patch.object(provider, "_get_async_client") as mock_client:
|
||||
mock_client.return_value.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await provider.generate("Hello", system_prompt="Be brief")
|
||||
assert result.content == snapshot("With system")
|
||||
|
||||
|
||||
class TestOpenAIProviderErrors:
|
||||
@pytest.fixture
|
||||
def provider(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
return OpenAIProvider()
|
||||
|
||||
def test_handle_rate_limit_error(self, provider):
|
||||
import openai
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
|
||||
|
||||
def test_handle_generic_error(self, provider):
|
||||
with pytest.raises(LLMProviderError):
|
||||
provider._handle_error(Exception("something went wrong"))
|
||||
Reference in New Issue
Block a user