feat: US-004 - Unified LLM Provider Abstraction

Create unified provider abstraction layer for direct LLM integrations beyond
HTTP specs, inspired by FuzzyAI's comprehensive provider system.

- Add BaseLLMProvider abstract class with standard interface (generate, chat,
  sync_generate, sync_chat methods)
- Implement OpenAIProvider supporting chat completions API
- Implement AnthropicProvider supporting messages API
- Create provider factory for instantiation by name (create_provider,
  get_provider_class)
- Add 60 unit tests covering all provider implementations
This commit is contained in:
Alexander Myasoedov
2026-01-28 18:34:38 +02:00
parent 29decc5c4e
commit 41567925aa
10 changed files with 1067 additions and 0 deletions
@@ -0,0 +1,222 @@
"""Tests for Anthropic provider."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
class TestAnthropicProviderInit:
def test_requires_api_key(self, monkeypatch):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
with pytest.raises(LLMProviderError) as exc:
AnthropicProvider()
assert "ANTHROPIC_API_KEY" in str(exc.value)
def test_accepts_api_key_directly(self, monkeypatch):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
provider = AnthropicProvider(api_key="test-key")
assert provider.api_key == snapshot("test-key")
def test_uses_env_api_key(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "env-key")
provider = AnthropicProvider()
assert provider.api_key == snapshot("env-key")
def test_default_model(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
provider = AnthropicProvider()
assert provider.model == snapshot("claude-3-haiku-20240307")
def test_custom_model(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
provider = AnthropicProvider(model="claude-3-5-sonnet-latest")
assert provider.model == snapshot("claude-3-5-sonnet-latest")
def test_custom_base_url(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
provider = AnthropicProvider(base_url="https://custom.api.com")
assert provider.base_url == snapshot("https://custom.api.com")
class TestAnthropicProviderMethods:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
return AnthropicProvider()
def test_get_supported_models(self, provider):
models = provider.get_supported_models()
assert "claude-3-haiku-20240307" in models
assert "claude-3-5-sonnet-latest" in models
def test_messages_to_dicts_simple(self, provider):
messages = [LLMMessage(role="user", content="Hello")]
system, chat = provider._messages_to_dicts(messages)
assert system is None
assert chat == snapshot([{"role": "user", "content": "Hello"}])
def test_messages_to_dicts_with_system(self, provider):
messages = [
LLMMessage(role="system", content="Be helpful"),
LLMMessage(role="user", content="Hello"),
]
system, chat = provider._messages_to_dicts(messages)
assert system == snapshot("Be helpful")
assert chat == snapshot([{"role": "user", "content": "Hello"}])
def test_messages_to_dicts_multi_turn(self, provider):
messages = [
LLMMessage(role="user", content="Hi"),
LLMMessage(role="assistant", content="Hello!"),
LLMMessage(role="user", content="How are you?"),
]
system, chat = provider._messages_to_dicts(messages)
assert system is None
assert chat == snapshot(
[
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
{"role": "user", "content": "How are you?"},
]
)
def test_parse_response(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Hi there!"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage.input_tokens = 10
mock_response.usage.output_tokens = 5
result = provider._parse_response(mock_response)
assert result.content == snapshot("Hi there!")
assert result.model == snapshot("claude-3-haiku-20240307")
assert result.finish_reason == snapshot("end_turn")
assert result.usage == snapshot({"input_tokens": 10, "output_tokens": 5})
def test_parse_response_empty_content(self, provider):
mock_response = MagicMock()
mock_response.content = []
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
result = provider._parse_response(mock_response)
assert result.content == snapshot("")
class TestAnthropicProviderSync:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
return AnthropicProvider()
def test_sync_generate(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.messages.create.return_value = mock_response
result = provider.sync_generate("Hello")
assert result.content == snapshot("Response")
def test_sync_chat(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Chat response"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.messages.create.return_value = mock_response
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
class TestAnthropicProviderAsync:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
return AnthropicProvider()
@pytest.mark.asyncio
async def test_generate(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Async response"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(
return_value=mock_response
)
result = await provider.generate("Hello")
assert result.content == snapshot("Async response")
@pytest.mark.asyncio
async def test_chat(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Async chat"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(
return_value=mock_response
)
result = await provider.chat(messages)
assert result.content == snapshot("Async chat")
@pytest.mark.asyncio
async def test_generate_with_system_prompt(self, provider):
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "With system"
mock_response.model = "claude-3-haiku-20240307"
mock_response.stop_reason = "end_turn"
mock_response.usage = None
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.messages.create = AsyncMock(
return_value=mock_response
)
result = await provider.generate("Hello", system_prompt="Be brief")
assert result.content == snapshot("With system")
class TestAnthropicProviderErrors:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
return AnthropicProvider()
def test_handle_rate_limit_error(self, provider):
import anthropic
with pytest.raises(LLMRateLimitError):
provider._handle_error(anthropic.RateLimitError("rate limited", response=MagicMock(), body={}))
def test_handle_api_error(self, provider):
import anthropic
with pytest.raises(LLMProviderError):
provider._handle_error(anthropic.APIError("api error", request=MagicMock(), body={}))
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
provider._handle_error(Exception("something went wrong"))
+88
View File
@@ -0,0 +1,88 @@
"""Tests for base LLM provider classes."""
import pytest
from inline_snapshot import snapshot
from agentic_security.llm_providers.base import (
BaseLLMProvider,
LLMMessage,
LLMProviderError,
LLMRateLimitError,
LLMResponse,
)
class TestLLMMessage:
def test_create_message(self):
msg = LLMMessage(role="user", content="hello")
assert msg.role == snapshot("user")
assert msg.content == snapshot("hello")
def test_system_message(self):
msg = LLMMessage(role="system", content="You are helpful")
assert msg.role == snapshot("system")
class TestLLMResponse:
def test_minimal_response(self):
resp = LLMResponse(content="Hello!")
assert resp.content == snapshot("Hello!")
assert resp.model is None
assert resp.finish_reason is None
assert resp.usage is None
def test_full_response(self):
resp = LLMResponse(
content="Hi there",
model="gpt-4o",
finish_reason="stop",
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)
assert resp.content == snapshot("Hi there")
assert resp.model == snapshot("gpt-4o")
assert resp.finish_reason == snapshot("stop")
assert resp.usage == snapshot(
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
)
class TestExceptions:
def test_provider_error_is_exception(self):
with pytest.raises(LLMProviderError):
raise LLMProviderError("test error")
def test_rate_limit_error_is_provider_error(self):
with pytest.raises(LLMProviderError):
raise LLMRateLimitError("rate limited")
def test_rate_limit_error_specific_catch(self):
with pytest.raises(LLMRateLimitError):
raise LLMRateLimitError("rate limited")
class TestBaseLLMProvider:
def test_cannot_instantiate_directly(self):
with pytest.raises(TypeError):
BaseLLMProvider(model="test") # type: ignore
def test_repr_format(self):
# Create a concrete implementation for testing
class ConcreteProvider(BaseLLMProvider):
async def generate(self, prompt, **kwargs):
return LLMResponse(content="")
async def chat(self, messages, **kwargs):
return LLMResponse(content="")
def sync_generate(self, prompt, **kwargs):
return LLMResponse(content="")
def sync_chat(self, messages, **kwargs):
return LLMResponse(content="")
@classmethod
def get_supported_models(cls):
return ["test-model"]
provider = ConcreteProvider(model="test-model")
assert repr(provider) == snapshot("ConcreteProvider(model='test-model')")
+107
View File
@@ -0,0 +1,107 @@
"""Tests for LLM provider factory."""
import pytest
from inline_snapshot import snapshot
from agentic_security.llm_providers.factory import (
create_provider,
get_provider_class,
list_providers,
register_provider,
)
from agentic_security.llm_providers.base import BaseLLMProvider, LLMProviderError, LLMResponse
class TestListProviders:
def test_includes_builtin_providers(self):
providers = list_providers()
assert "openai" in providers
assert "anthropic" in providers
def test_returns_sorted_list(self):
providers = list_providers()
assert providers == sorted(providers)
class TestGetProviderClass:
def test_get_openai(self):
from agentic_security.llm_providers.openai_provider import OpenAIProvider
cls = get_provider_class("openai")
assert cls is OpenAIProvider
def test_get_anthropic(self):
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
cls = get_provider_class("anthropic")
assert cls is AnthropicProvider
def test_case_insensitive(self):
cls1 = get_provider_class("OpenAI")
cls2 = get_provider_class("OPENAI")
cls3 = get_provider_class("openai")
assert cls1 is cls2 is cls3
def test_unknown_provider_raises(self):
with pytest.raises(LLMProviderError) as exc:
get_provider_class("unknown")
assert "Unknown provider: unknown" in str(exc.value)
assert "Available:" in str(exc.value)
class TestRegisterProvider:
def test_register_custom_provider(self):
class CustomProvider(BaseLLMProvider):
async def generate(self, prompt, **kwargs):
return LLMResponse(content="custom")
async def chat(self, messages, **kwargs):
return LLMResponse(content="custom")
def sync_generate(self, prompt, **kwargs):
return LLMResponse(content="custom")
def sync_chat(self, messages, **kwargs):
return LLMResponse(content="custom")
@classmethod
def get_supported_models(cls):
return ["custom-model"]
register_provider("custom", CustomProvider)
assert "custom" in list_providers()
assert get_provider_class("custom") is CustomProvider
class TestCreateProvider:
def test_create_openai_with_default_model(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = create_provider("openai")
assert provider.model == snapshot("gpt-4o-mini")
def test_create_openai_with_custom_model(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = create_provider("openai", model="gpt-4o")
assert provider.model == snapshot("gpt-4o")
def test_create_anthropic_with_default_model(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
provider = create_provider("anthropic")
assert provider.model == snapshot("claude-3-haiku-20240307")
def test_create_anthropic_with_custom_model(self, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
provider = create_provider("anthropic", model="claude-3-5-sonnet-latest")
assert provider.model == snapshot("claude-3-5-sonnet-latest")
def test_create_with_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
provider = create_provider("openai", api_key="direct-key")
assert provider.api_key == snapshot("direct-key")
def test_create_unknown_provider_raises(self):
with pytest.raises(LLMProviderError):
create_provider("unknown")
def test_case_insensitive(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = create_provider("OpenAI")
assert provider.__class__.__name__ == snapshot("OpenAIProvider")
@@ -0,0 +1,204 @@
"""Tests for OpenAI provider."""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from inline_snapshot import snapshot
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.base import LLMMessage, LLMProviderError, LLMRateLimitError
class TestOpenAIProviderInit:
def test_requires_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(LLMProviderError) as exc:
OpenAIProvider()
assert "OPENAI_API_KEY" in str(exc.value)
def test_accepts_api_key_directly(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
provider = OpenAIProvider(api_key="test-key")
assert provider.api_key == snapshot("test-key")
def test_uses_env_api_key(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "env-key")
provider = OpenAIProvider()
assert provider.api_key == snapshot("env-key")
def test_default_model(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = OpenAIProvider()
assert provider.model == snapshot("gpt-4o-mini")
def test_custom_model(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = OpenAIProvider(model="gpt-4o")
assert provider.model == snapshot("gpt-4o")
def test_custom_base_url(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
provider = OpenAIProvider(base_url="https://custom.api.com")
assert provider.base_url == snapshot("https://custom.api.com")
class TestOpenAIProviderMethods:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
return OpenAIProvider()
def test_get_supported_models(self, provider):
models = provider.get_supported_models()
assert "gpt-4o" in models
assert "gpt-4o-mini" in models
assert "gpt-3.5-turbo" in models
def test_messages_to_dicts(self, provider):
messages = [
LLMMessage(role="system", content="Be helpful"),
LLMMessage(role="user", content="Hello"),
]
result = provider._messages_to_dicts(messages)
assert result == snapshot(
[
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "Hello"},
]
)
def test_parse_response(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hi there!"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o"
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
mock_response.usage.total_tokens = 15
result = provider._parse_response(mock_response)
assert result.content == snapshot("Hi there!")
assert result.model == snapshot("gpt-4o")
assert result.finish_reason == snapshot("stop")
assert result.usage == snapshot(
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
)
def test_parse_response_empty_content(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = None
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o"
mock_response.usage = None
result = provider._parse_response(mock_response)
assert result.content == snapshot("")
class TestOpenAIProviderSync:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
return OpenAIProvider()
def test_sync_generate(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o-mini"
mock_response.usage = None
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
result = provider.sync_generate("Hello")
assert result.content == snapshot("Response")
def test_sync_chat(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Chat response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o-mini"
mock_response.usage = None
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_client") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
class TestOpenAIProviderAsync:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
return OpenAIProvider()
@pytest.mark.asyncio
async def test_generate(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Async response"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o-mini"
mock_response.usage = None
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await provider.generate("Hello")
assert result.content == snapshot("Async response")
@pytest.mark.asyncio
async def test_chat(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Async chat"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o-mini"
mock_response.usage = None
messages = [LLMMessage(role="user", content="Hi")]
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await provider.chat(messages)
assert result.content == snapshot("Async chat")
@pytest.mark.asyncio
async def test_generate_with_system_prompt(self, provider):
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "With system"
mock_response.choices[0].finish_reason = "stop"
mock_response.model = "gpt-4o-mini"
mock_response.usage = None
with patch.object(provider, "_get_async_client") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await provider.generate("Hello", system_prompt="Be brief")
assert result.content == snapshot("With system")
class TestOpenAIProviderErrors:
@pytest.fixture
def provider(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
return OpenAIProvider()
def test_handle_rate_limit_error(self, provider):
import openai
with pytest.raises(LLMRateLimitError):
provider._handle_error(openai.RateLimitError("rate limited", response=MagicMock(), body={}))
def test_handle_generic_error(self, provider):
with pytest.raises(LLMProviderError):
provider._handle_error(Exception("something went wrong"))