From 41567925aa602d4270e2662fa15ab3ff796623dc Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Wed, 28 Jan 2026 18:34:38 +0200 Subject: [PATCH] feat: US-004 - Unified LLM Provider Abstraction Create unified provider abstraction layer for direct LLM integrations beyond HTTP specs, inspired by FuzzyAI's comprehensive provider system. - Add BaseLLMProvider abstract class with standard interface (generate, chat, sync_generate, sync_chat methods) - Implement OpenAIProvider supporting chat completions API - Implement AnthropicProvider supporting messages API - Create provider factory for instantiation by name (create_provider, get_provider_class) - Add 60 unit tests covering all provider implementations --- agentic_security/llm_providers/__init__.py | 22 ++ .../llm_providers/anthropic_provider.py | 154 ++++++++++++ agentic_security/llm_providers/base.py | 78 ++++++ agentic_security/llm_providers/factory.py | 66 ++++++ .../llm_providers/openai_provider.py | 126 ++++++++++ tests/unit/llm_providers/__init__.py | 0 .../llm_providers/test_anthropic_provider.py | 222 ++++++++++++++++++ tests/unit/llm_providers/test_base.py | 88 +++++++ tests/unit/llm_providers/test_factory.py | 107 +++++++++ .../llm_providers/test_openai_provider.py | 204 ++++++++++++++++ 10 files changed, 1067 insertions(+) create mode 100644 agentic_security/llm_providers/__init__.py create mode 100644 agentic_security/llm_providers/anthropic_provider.py create mode 100644 agentic_security/llm_providers/base.py create mode 100644 agentic_security/llm_providers/factory.py create mode 100644 agentic_security/llm_providers/openai_provider.py create mode 100644 tests/unit/llm_providers/__init__.py create mode 100644 tests/unit/llm_providers/test_anthropic_provider.py create mode 100644 tests/unit/llm_providers/test_base.py create mode 100644 tests/unit/llm_providers/test_factory.py create mode 100644 tests/unit/llm_providers/test_openai_provider.py diff --git a/agentic_security/llm_providers/__init__.py 
class AnthropicProvider(BaseLLMProvider):
    """Anthropic provider backed by the Messages API.

    The ``anthropic`` SDK is imported lazily inside the methods that need it,
    so constructing a provider only requires an API key. Sync and async
    clients are created on first use and cached on the instance.
    """

    DEFAULT_MODEL = "claude-3-haiku-20240307"
    API_KEY_ENV = "ANTHROPIC_API_KEY"
    # Sent as ``max_tokens`` whenever the caller does not supply one.
    DEFAULT_MAX_TOKENS = 1024

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: str | None = None,
        base_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Configure the provider.

        Args:
            model: Model name sent with every request.
            api_key: Explicit key; falls back to ``$ANTHROPIC_API_KEY``.
            base_url: Optional alternative API endpoint.
            **kwargs: Forwarded unchanged to ``BaseLLMProvider.__init__``.

        Raises:
            LLMProviderError: If no API key is given and the env var is unset.
        """
        super().__init__(model, **kwargs)
        self.api_key = api_key or os.environ.get(self.API_KEY_ENV)
        if not self.api_key:
            raise LLMProviderError(f"{self.API_KEY_ENV} not set")
        self.base_url = base_url
        self._client: Any = None
        self._async_client: Any = None

    def _client_kwargs(self) -> dict[str, Any]:
        """Return constructor arguments shared by the sync and async clients."""
        options: dict[str, Any] = {"api_key": self.api_key}
        if self.base_url:
            # Only pass base_url when set so the SDK keeps its own default.
            options["base_url"] = self.base_url
        return options

    def _get_client(self) -> Any:
        """Return the cached sync client, creating it on first use."""
        if self._client is None:
            import anthropic

            self._client = anthropic.Anthropic(**self._client_kwargs())
        return self._client

    def _get_async_client(self) -> Any:
        """Return the cached async client, creating it on first use."""
        if self._async_client is None:
            import anthropic

            self._async_client = anthropic.AsyncAnthropic(**self._client_kwargs())
        return self._async_client

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """Return the model names this provider advertises."""
        return [
            "claude-3-haiku-20240307",
            "claude-3-sonnet-20240229",
            "claude-3-opus-latest",
            "claude-3-5-haiku-latest",
            "claude-3-5-sonnet-latest",
        ]

    def _messages_to_dicts(
        self, messages: list[LLMMessage]
    ) -> tuple[str | None, list[dict[str, str]]]:
        """Split messages into a system prompt and Anthropic chat messages.

        System messages are not sent in the message list; they are collected
        and returned separately. When several system messages are present
        they are joined with a blank line instead of silently keeping only
        the last one.

        Returns:
            ``(system_prompt, chat_messages)``; ``system_prompt`` is ``None``
            when the conversation holds no system message.
        """
        system_parts = [m.content for m in messages if m.role == "system"]
        chat_messages = [
            {"role": m.role, "content": m.content}
            for m in messages
            if m.role != "system"
        ]
        system_prompt = "\n\n".join(system_parts) if system_parts else None
        return system_prompt, chat_messages

    def _parse_response(self, response: Any) -> LLMResponse:
        """Convert an SDK response object into an ``LLMResponse``.

        The text of every content block that carries a string ``text``
        attribute is concatenated, so output is not lost when the first
        block is a non-text block or the reply spans several text blocks.
        """
        texts = (getattr(block, "text", None) for block in response.content or [])
        content = "".join(text for text in texts if isinstance(text, str))
        usage = None
        if response.usage:
            usage = {
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
            }
        return LLMResponse(
            content=content,
            model=response.model,
            finish_reason=response.stop_reason,
            usage=usage,
        )

    def _wrap_error(self, e: Exception) -> LLMProviderError:
        """Map an arbitrary exception onto the package's exception hierarchy."""
        import anthropic

        if isinstance(e, anthropic.RateLimitError):
            return LLMRateLimitError(str(e))
        # Every other failure (SDK APIError included) becomes the base error.
        return LLMProviderError(str(e))

    def _handle_error(self, e: Exception) -> None:
        """Raise the wrapped form of ``e``; this method never returns normally.

        Raises:
            LLMRateLimitError: If ``e`` is an ``anthropic.RateLimitError``.
            LLMProviderError: For every other exception.
        """
        raise self._wrap_error(e) from e

    @staticmethod
    def _prompt_to_messages(prompt: str, kwargs: dict[str, Any]) -> list[LLMMessage]:
        """Build the message list for a single prompt.

        ``system_prompt`` is always removed from ``kwargs`` so it is never
        forwarded to the API; a falsy value adds no system message.
        """
        messages = [LLMMessage(role="user", content=prompt)]
        if system_prompt := kwargs.pop("system_prompt", None):
            messages.insert(0, LLMMessage(role="system", content=system_prompt))
        return messages

    def _build_request(
        self, messages: list[LLMMessage], overrides: dict[str, Any]
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for ``messages.create``.

        Caller-supplied ``overrides`` are applied last, so they win over the
        defaults (including ``max_tokens``, ``model`` and ``system``).
        """
        system_prompt, chat_messages = self._messages_to_dicts(messages)
        request: dict[str, Any] = {
            "model": self.model,
            "messages": chat_messages,
            "max_tokens": self.DEFAULT_MAX_TOKENS,
        }
        if system_prompt:
            request["system"] = system_prompt
        request.update(overrides)
        return request

    async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Generate a reply to one prompt (optional ``system_prompt`` kwarg)."""
        return await self.chat(self._prompt_to_messages(prompt, kwargs), **kwargs)

    async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Send a conversation through the async client.

        Raises:
            LLMRateLimitError: When the SDK reports a rate limit.
            LLMProviderError: For any other failure during the request.
        """
        client = self._get_async_client()
        request = self._build_request(messages, kwargs)
        try:
            response = await client.messages.create(**request)
            return self._parse_response(response)
        except Exception as e:
            raise self._wrap_error(e) from e

    def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Synchronous version of ``generate()``."""
        return self.sync_chat(self._prompt_to_messages(prompt, kwargs), **kwargs)

    def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Synchronous version of ``chat()``; raises the same exceptions."""
        client = self._get_client()
        request = self._build_request(messages, kwargs)
        try:
            response = client.messages.create(**request)
            return self._parse_response(response)
        except Exception as e:
            raise self._wrap_error(e) from e

    async def close(self) -> None:
        """Close whichever clients were created and forget them.

        The cached references are cleared first, so a later call lazily
        builds a fresh client instead of reusing a closed one.
        """
        sync_client, self._client = self._client, None
        async_client, self._async_client = self._async_client, None
        try:
            if sync_client is not None:
                sync_client.close()
        finally:
            # Still release the async client if closing the sync one failed.
            if async_client is not None:
                await async_client.close()
+ +Inspired by FuzzyAI's provider architecture, providing a simple interface +for both sync and async LLM interactions. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +class LLMProviderError(Exception): + """Base exception for LLM provider errors.""" + + +class LLMRateLimitError(LLMProviderError): + """Raised when rate limit is exceeded.""" + + +@dataclass +class LLMMessage: + """A message in a chat conversation.""" + role: str # "system", "user", or "assistant" + content: str + + +@dataclass +class LLMResponse: + """Response from an LLM provider.""" + content: str + model: str | None = None + finish_reason: str | None = None + usage: dict[str, int] | None = None + + +class BaseLLMProvider(ABC): + """Abstract base class for LLM providers. + + Subclasses must implement generate() and chat() methods for both + sync and async variants. + """ + + def __init__(self, model: str, **kwargs: Any) -> None: + self.model = model + self._extra = kwargs + + @abstractmethod + async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse: + """Generate a response from a single prompt.""" + ... + + @abstractmethod + async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: + """Generate a response from a chat conversation.""" + ... + + @abstractmethod + def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse: + """Synchronous version of generate().""" + ... + + @abstractmethod + def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse: + """Synchronous version of chat().""" + ... + + @classmethod + @abstractmethod + def get_supported_models(cls) -> list[str]: + """Return list of supported model names.""" + ... + + async def close(self) -> None: + """Close any open connections. 
Override if cleanup is needed.""" + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(model={self.model!r})" diff --git a/agentic_security/llm_providers/factory.py b/agentic_security/llm_providers/factory.py new file mode 100644 index 0000000..2e4edb3 --- /dev/null +++ b/agentic_security/llm_providers/factory.py @@ -0,0 +1,66 @@ +"""Factory for creating LLM provider instances.""" + +from typing import Any + +from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError + +# Provider registry mapping name to class +_PROVIDERS: dict[str, type[BaseLLMProvider]] = {} + + +def _ensure_registered() -> None: + """Lazy registration of built-in providers.""" + if _PROVIDERS: + return + from agentic_security.llm_providers.openai_provider import OpenAIProvider + from agentic_security.llm_providers.anthropic_provider import AnthropicProvider + _PROVIDERS["openai"] = OpenAIProvider + _PROVIDERS["anthropic"] = AnthropicProvider + + +def register_provider(name: str, provider_class: type[BaseLLMProvider]) -> None: + """Register a custom provider class.""" + _ensure_registered() + _PROVIDERS[name.lower()] = provider_class + + +def get_provider_class(name: str) -> type[BaseLLMProvider]: + """Get provider class by name.""" + _ensure_registered() + name_lower = name.lower() + if name_lower not in _PROVIDERS: + available = ", ".join(sorted(_PROVIDERS.keys())) + raise LLMProviderError(f"Unknown provider: {name}. Available: {available}") + return _PROVIDERS[name_lower] + + +def list_providers() -> list[str]: + """List all available provider names.""" + _ensure_registered() + return sorted(_PROVIDERS.keys()) + + +def create_provider( + name: str, + model: str | None = None, + **kwargs: Any, +) -> BaseLLMProvider: + """Create a provider instance by name. + + Args: + name: Provider name ("openai", "anthropic", etc.) + model: Model name. If None, uses provider's default. + **kwargs: Additional arguments passed to provider constructor. 
class OpenAIProvider(BaseLLMProvider):
    """OpenAI provider backed by the chat completions API.

    The ``openai`` SDK is imported lazily inside the methods that need it,
    so constructing a provider only requires an API key. Sync and async
    clients are created on first use and cached on the instance.
    """

    DEFAULT_MODEL = "gpt-4o-mini"
    API_KEY_ENV = "OPENAI_API_KEY"

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: str | None = None,
        base_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Configure the provider.

        Args:
            model: Model name sent with every request.
            api_key: Explicit key; falls back to ``$OPENAI_API_KEY``.
            base_url: Optional alternative API endpoint.
            **kwargs: Forwarded unchanged to ``BaseLLMProvider.__init__``.

        Raises:
            LLMProviderError: If no API key is given and the env var is unset.
        """
        super().__init__(model, **kwargs)
        self.api_key = api_key or os.environ.get(self.API_KEY_ENV)
        if not self.api_key:
            raise LLMProviderError(f"{self.API_KEY_ENV} not set")
        self.base_url = base_url
        self._client: Any = None
        self._async_client: Any = None

    def _get_client(self) -> Any:
        """Return the cached sync client, creating it on first use."""
        if self._client is None:
            import openai

            self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
        return self._client

    def _get_async_client(self) -> Any:
        """Return the cached async client, creating it on first use."""
        if self._async_client is None:
            import openai

            self._async_client = openai.AsyncOpenAI(
                api_key=self.api_key, base_url=self.base_url
            )
        return self._async_client

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """Return the model names this provider advertises."""
        return [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "gpt-4o-mini",
            "o1-mini",
            "o1-preview",
            "o3-mini",
        ]

    def _messages_to_dicts(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert messages to the ``{"role", "content"}`` dicts the API takes."""
        return [{"role": m.role, "content": m.content} for m in messages]

    def _parse_response(self, response: Any) -> LLMResponse:
        """Convert an SDK response object into an ``LLMResponse``.

        Only the first choice is used; a ``None`` message content becomes
        an empty string.

        Raises:
            LLMProviderError: If the response carries no choices, instead of
                leaking a bare ``IndexError``.
        """
        choices = response.choices
        if not choices:
            raise LLMProviderError("OpenAI response contained no choices")
        choice = choices[0]
        usage = None
        if response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            finish_reason=choice.finish_reason,
            usage=usage,
        )

    def _wrap_error(self, e: Exception) -> LLMProviderError:
        """Map an arbitrary exception onto the package's exception hierarchy."""
        import openai

        if isinstance(e, openai.RateLimitError):
            return LLMRateLimitError(str(e))
        return LLMProviderError(str(e))

    def _handle_error(self, e: Exception) -> None:
        """Raise the wrapped form of ``e``; this method never returns normally.

        Raises:
            LLMRateLimitError: If ``e`` is an ``openai.RateLimitError``.
            LLMProviderError: For every other exception.
        """
        raise self._wrap_error(e) from e

    @staticmethod
    def _prompt_to_messages(prompt: str, kwargs: dict[str, Any]) -> list[LLMMessage]:
        """Build the message list for a single prompt.

        ``system_prompt`` is always removed from ``kwargs`` so it is never
        forwarded to the API; a falsy value adds no system message.
        """
        messages = [LLMMessage(role="user", content=prompt)]
        if system_prompt := kwargs.pop("system_prompt", None):
            messages.insert(0, LLMMessage(role="system", content=system_prompt))
        return messages

    def _build_request(
        self, messages: list[LLMMessage], overrides: dict[str, Any]
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``.

        Built as a dict so caller-supplied ``overrides`` (e.g. ``model=``)
        replace the defaults instead of triggering a duplicate-keyword
        ``TypeError``.
        """
        request: dict[str, Any] = {
            "model": self.model,
            "messages": self._messages_to_dicts(messages),
        }
        request.update(overrides)
        return request

    async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Generate a reply to one prompt (optional ``system_prompt`` kwarg)."""
        return await self.chat(self._prompt_to_messages(prompt, kwargs), **kwargs)

    async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Send a conversation through the async client.

        Raises:
            LLMRateLimitError: When the SDK reports a rate limit.
            LLMProviderError: For any other failure during the request.
        """
        client = self._get_async_client()
        request = self._build_request(messages, kwargs)
        try:
            response = await client.chat.completions.create(**request)
            return self._parse_response(response)
        except LLMProviderError:
            # Already in the package hierarchy (e.g. empty choices): keep as is.
            raise
        except Exception as e:
            raise self._wrap_error(e) from e

    def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Synchronous version of ``generate()``."""
        return self.sync_chat(self._prompt_to_messages(prompt, kwargs), **kwargs)

    def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Synchronous version of ``chat()``; raises the same exceptions."""
        client = self._get_client()
        request = self._build_request(messages, kwargs)
        try:
            response = client.chat.completions.create(**request)
            return self._parse_response(response)
        except LLMProviderError:
            raise
        except Exception as e:
            raise self._wrap_error(e) from e

    async def close(self) -> None:
        """Close whichever clients were created and forget them.

        The cached references are cleared first, so a later call lazily
        builds a fresh client instead of reusing a closed one.
        """
        sync_client, self._client = self._client, None
        async_client, self._async_client = self._async_client, None
        try:
            if sync_client is not None:
                sync_client.close()
        finally:
            # Still release the async client if closing the sync one failed.
            if async_client is not None:
                await async_client.close()
AnthropicProvider(model="claude-3-5-sonnet-latest") + assert provider.model == snapshot("claude-3-5-sonnet-latest") + + def test_custom_base_url(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + provider = AnthropicProvider(base_url="https://custom.api.com") + assert provider.base_url == snapshot("https://custom.api.com") + + +class TestAnthropicProviderMethods: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + return AnthropicProvider() + + def test_get_supported_models(self, provider): + models = provider.get_supported_models() + assert "claude-3-haiku-20240307" in models + assert "claude-3-5-sonnet-latest" in models + + def test_messages_to_dicts_simple(self, provider): + messages = [LLMMessage(role="user", content="Hello")] + system, chat = provider._messages_to_dicts(messages) + assert system is None + assert chat == snapshot([{"role": "user", "content": "Hello"}]) + + def test_messages_to_dicts_with_system(self, provider): + messages = [ + LLMMessage(role="system", content="Be helpful"), + LLMMessage(role="user", content="Hello"), + ] + system, chat = provider._messages_to_dicts(messages) + assert system == snapshot("Be helpful") + assert chat == snapshot([{"role": "user", "content": "Hello"}]) + + def test_messages_to_dicts_multi_turn(self, provider): + messages = [ + LLMMessage(role="user", content="Hi"), + LLMMessage(role="assistant", content="Hello!"), + LLMMessage(role="user", content="How are you?"), + ] + system, chat = provider._messages_to_dicts(messages) + assert system is None + assert chat == snapshot( + [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "How are you?"}, + ] + ) + + def test_parse_response(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Hi there!" 
+ mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + result = provider._parse_response(mock_response) + assert result.content == snapshot("Hi there!") + assert result.model == snapshot("claude-3-haiku-20240307") + assert result.finish_reason == snapshot("end_turn") + assert result.usage == snapshot({"input_tokens": 10, "output_tokens": 5}) + + def test_parse_response_empty_content(self, provider): + mock_response = MagicMock() + mock_response.content = [] + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + result = provider._parse_response(mock_response) + assert result.content == snapshot("") + + +class TestAnthropicProviderSync: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + return AnthropicProvider() + + def test_sync_generate(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + with patch.object(provider, "_get_client") as mock_client: + mock_client.return_value.messages.create.return_value = mock_response + result = provider.sync_generate("Hello") + assert result.content == snapshot("Response") + + def test_sync_chat(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Chat response" + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + messages = [LLMMessage(role="user", content="Hi")] + + with patch.object(provider, "_get_client") as mock_client: + mock_client.return_value.messages.create.return_value = mock_response + result = provider.sync_chat(messages) + assert 
result.content == snapshot("Chat response") + + +class TestAnthropicProviderAsync: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + return AnthropicProvider() + + @pytest.mark.asyncio + async def test_generate(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Async response" + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock( + return_value=mock_response + ) + result = await provider.generate("Hello") + assert result.content == snapshot("Async response") + + @pytest.mark.asyncio + async def test_chat(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Async chat" + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + messages = [LLMMessage(role="user", content="Hi")] + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock( + return_value=mock_response + ) + result = await provider.chat(messages) + assert result.content == snapshot("Async chat") + + @pytest.mark.asyncio + async def test_generate_with_system_prompt(self, provider): + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "With system" + mock_response.model = "claude-3-haiku-20240307" + mock_response.stop_reason = "end_turn" + mock_response.usage = None + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.messages.create = AsyncMock( + return_value=mock_response + ) + result = await provider.generate("Hello", system_prompt="Be brief") + assert result.content == snapshot("With system") + + 
+class TestAnthropicProviderErrors: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + return AnthropicProvider() + + def test_handle_rate_limit_error(self, provider): + import anthropic + with pytest.raises(LLMRateLimitError): + provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={})) + + def test_handle_api_error(self, provider): + import anthropic + with pytest.raises(LLMProviderError): + provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={})) + + def test_handle_generic_error(self, provider): + with pytest.raises(LLMProviderError): + provider._handle_error(Exception("something went wrong")) diff --git a/tests/unit/llm_providers/test_base.py b/tests/unit/llm_providers/test_base.py new file mode 100644 index 0000000..83663d1 --- /dev/null +++ b/tests/unit/llm_providers/test_base.py @@ -0,0 +1,88 @@ +"""Tests for base LLM provider classes.""" + +import pytest +from inline_snapshot import snapshot + +from agentic_security.llm_providers.base import ( + BaseLLMProvider, + LLMMessage, + LLMProviderError, + LLMRateLimitError, + LLMResponse, +) + + +class TestLLMMessage: + def test_create_message(self): + msg = LLMMessage(role="user", content="hello") + assert msg.role == snapshot("user") + assert msg.content == snapshot("hello") + + def test_system_message(self): + msg = LLMMessage(role="system", content="You are helpful") + assert msg.role == snapshot("system") + + +class TestLLMResponse: + def test_minimal_response(self): + resp = LLMResponse(content="Hello!") + assert resp.content == snapshot("Hello!") + assert resp.model is None + assert resp.finish_reason is None + assert resp.usage is None + + def test_full_response(self): + resp = LLMResponse( + content="Hi there", + model="gpt-4o", + finish_reason="stop", + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + ) + assert resp.content == snapshot("Hi there") + assert 
resp.model == snapshot("gpt-4o") + assert resp.finish_reason == snapshot("stop") + assert resp.usage == snapshot( + {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + + +class TestExceptions: + def test_provider_error_is_exception(self): + with pytest.raises(LLMProviderError): + raise LLMProviderError("test error") + + def test_rate_limit_error_is_provider_error(self): + with pytest.raises(LLMProviderError): + raise LLMRateLimitError("rate limited") + + def test_rate_limit_error_specific_catch(self): + with pytest.raises(LLMRateLimitError): + raise LLMRateLimitError("rate limited") + + +class TestBaseLLMProvider: + def test_cannot_instantiate_directly(self): + with pytest.raises(TypeError): + BaseLLMProvider(model="test") # type: ignore + + def test_repr_format(self): + # Create a concrete implementation for testing + class ConcreteProvider(BaseLLMProvider): + async def generate(self, prompt, **kwargs): + return LLMResponse(content="") + + async def chat(self, messages, **kwargs): + return LLMResponse(content="") + + def sync_generate(self, prompt, **kwargs): + return LLMResponse(content="") + + def sync_chat(self, messages, **kwargs): + return LLMResponse(content="") + + @classmethod + def get_supported_models(cls): + return ["test-model"] + + provider = ConcreteProvider(model="test-model") + assert repr(provider) == snapshot("ConcreteProvider(model='test-model')") diff --git a/tests/unit/llm_providers/test_factory.py b/tests/unit/llm_providers/test_factory.py new file mode 100644 index 0000000..dbcd633 --- /dev/null +++ b/tests/unit/llm_providers/test_factory.py @@ -0,0 +1,107 @@ +"""Tests for LLM provider factory.""" + +import pytest +from inline_snapshot import snapshot + +from agentic_security.llm_providers.factory import ( + create_provider, + get_provider_class, + list_providers, + register_provider, +) +from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse + + +class TestListProviders: + def 
test_includes_builtin_providers(self): + providers = list_providers() + assert "openai" in providers + assert "anthropic" in providers + + def test_returns_sorted_list(self): + providers = list_providers() + assert providers == sorted(providers) + + +class TestGetProviderClass: + def test_get_openai(self): + from agentic_security.llm_providers.openai_provider import OpenAIProvider + cls = get_provider_class("openai") + assert cls is OpenAIProvider + + def test_get_anthropic(self): + from agentic_security.llm_providers.anthropic_provider import AnthropicProvider + cls = get_provider_class("anthropic") + assert cls is AnthropicProvider + + def test_case_insensitive(self): + cls1 = get_provider_class("OpenAI") + cls2 = get_provider_class("OPENAI") + cls3 = get_provider_class("openai") + assert cls1 is cls2 is cls3 + + def test_unknown_provider_raises(self): + with pytest.raises(LLMProviderError) as exc: + get_provider_class("unknown") + assert "Unknown provider: unknown" in str(exc.value) + assert "Available:" in str(exc.value) + + +class TestRegisterProvider: + def test_register_custom_provider(self): + class CustomProvider(BaseLLMProvider): + async def generate(self, prompt, **kwargs): + return LLMResponse(content="custom") + + async def chat(self, messages, **kwargs): + return LLMResponse(content="custom") + + def sync_generate(self, prompt, **kwargs): + return LLMResponse(content="custom") + + def sync_chat(self, messages, **kwargs): + return LLMResponse(content="custom") + + @classmethod + def get_supported_models(cls): + return ["custom-model"] + + register_provider("custom", CustomProvider) + assert "custom" in list_providers() + assert get_provider_class("custom") is CustomProvider + + +class TestCreateProvider: + def test_create_openai_with_default_model(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = create_provider("openai") + assert provider.model == snapshot("gpt-4o-mini") + + def 
test_create_openai_with_custom_model(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = create_provider("openai", model="gpt-4o") + assert provider.model == snapshot("gpt-4o") + + def test_create_anthropic_with_default_model(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + provider = create_provider("anthropic") + assert provider.model == snapshot("claude-3-haiku-20240307") + + def test_create_anthropic_with_custom_model(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + provider = create_provider("anthropic", model="claude-3-5-sonnet-latest") + assert provider.model == snapshot("claude-3-5-sonnet-latest") + + def test_create_with_api_key(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + provider = create_provider("openai", api_key="direct-key") + assert provider.api_key == snapshot("direct-key") + + def test_create_unknown_provider_raises(self): + with pytest.raises(LLMProviderError): + create_provider("unknown") + + def test_case_insensitive(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = create_provider("OpenAI") + assert provider.__class__.__name__ == snapshot("OpenAIProvider") diff --git a/tests/unit/llm_providers/test_openai_provider.py b/tests/unit/llm_providers/test_openai_provider.py new file mode 100644 index 0000000..85b90d0 --- /dev/null +++ b/tests/unit/llm_providers/test_openai_provider.py @@ -0,0 +1,204 @@ +"""Tests for OpenAI provider.""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from inline_snapshot import snapshot + +from agentic_security.llm_providers.openai_provider import OpenAIProvider +from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError + + +class TestOpenAIProviderInit: + def test_requires_api_key(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with pytest.raises(LLMProviderError) as exc: + 
OpenAIProvider() + assert "OPENAI_API_KEY" in str(exc.value) + + def test_accepts_api_key_directly(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + provider = OpenAIProvider(api_key="test-key") + assert provider.api_key == snapshot("test-key") + + def test_uses_env_api_key(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + provider = OpenAIProvider() + assert provider.api_key == snapshot("env-key") + + def test_default_model(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = OpenAIProvider() + assert provider.model == snapshot("gpt-4o-mini") + + def test_custom_model(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = OpenAIProvider(model="gpt-4o") + assert provider.model == snapshot("gpt-4o") + + def test_custom_base_url(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + provider = OpenAIProvider(base_url="https://custom.api.com") + assert provider.base_url == snapshot("https://custom.api.com") + + +class TestOpenAIProviderMethods: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + return OpenAIProvider() + + def test_get_supported_models(self, provider): + models = provider.get_supported_models() + assert "gpt-4o" in models + assert "gpt-4o-mini" in models + assert "gpt-3.5-turbo" in models + + def test_messages_to_dicts(self, provider): + messages = [ + LLMMessage(role="system", content="Be helpful"), + LLMMessage(role="user", content="Hello"), + ] + result = provider._messages_to_dicts(messages) + assert result == snapshot( + [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Hello"}, + ] + ) + + def test_parse_response(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Hi there!" 
+ mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + result = provider._parse_response(mock_response) + assert result.content == snapshot("Hi there!") + assert result.model == snapshot("gpt-4o") + assert result.finish_reason == snapshot("stop") + assert result.usage == snapshot( + {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + + def test_parse_response_empty_content(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = None + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o" + mock_response.usage = None + + result = provider._parse_response(mock_response) + assert result.content == snapshot("") + + +class TestOpenAIProviderSync: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + return OpenAIProvider() + + def test_sync_generate(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o-mini" + mock_response.usage = None + + with patch.object(provider, "_get_client") as mock_client: + mock_client.return_value.chat.completions.create.return_value = mock_response + result = provider.sync_generate("Hello") + assert result.content == snapshot("Response") + + def test_sync_chat(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Chat response" + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o-mini" + mock_response.usage = None + + messages = [LLMMessage(role="user", content="Hi")] + + with patch.object(provider, "_get_client") as mock_client: + 
mock_client.return_value.chat.completions.create.return_value = mock_response + result = provider.sync_chat(messages) + assert result.content == snapshot("Chat response") + + +class TestOpenAIProviderAsync: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + return OpenAIProvider() + + @pytest.mark.asyncio + async def test_generate(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Async response" + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o-mini" + mock_response.usage = None + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.chat.completions.create = AsyncMock( + return_value=mock_response + ) + result = await provider.generate("Hello") + assert result.content == snapshot("Async response") + + @pytest.mark.asyncio + async def test_chat(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Async chat" + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o-mini" + mock_response.usage = None + + messages = [LLMMessage(role="user", content="Hi")] + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.chat.completions.create = AsyncMock( + return_value=mock_response + ) + result = await provider.chat(messages) + assert result.content == snapshot("Async chat") + + @pytest.mark.asyncio + async def test_generate_with_system_prompt(self, provider): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "With system" + mock_response.choices[0].finish_reason = "stop" + mock_response.model = "gpt-4o-mini" + mock_response.usage = None + + with patch.object(provider, "_get_async_client") as mock_client: + mock_client.return_value.chat.completions.create = 
AsyncMock( + return_value=mock_response + ) + result = await provider.generate("Hello", system_prompt="Be brief") + assert result.content == snapshot("With system") + + +class TestOpenAIProviderErrors: + @pytest.fixture + def provider(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + return OpenAIProvider() + + def test_handle_rate_limit_error(self, provider): + import openai + with pytest.raises(LLMRateLimitError): + provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={})) + + def test_handle_generic_error(self, provider): + with pytest.raises(LLMProviderError): + provider._handle_error(Exception("something went wrong"))