This commit is contained in:
Alexander Myasoedov
2026-06-03 15:05:59 +03:00
parent 67cedfb116
commit a193ef9c2c
2 changed files with 50 additions and 14 deletions
Regular → Executable
+3 -2
View File
@@ -18,7 +18,6 @@ from typing import Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
NO_ARGUMENT_TOOLS = {"get_data_config", "get_spec_templates", "stop_scan"}
@@ -56,7 +55,9 @@ async def run_client(agentic_security_url: str | None, call_tool: str | None) ->
print("Available Agentic Security MCP tools:")
for tool in tools.tools:
description_lines = (tool.description or "").strip().splitlines()
description = description_lines[0] if description_lines else "No description"
description = (
description_lines[0] if description_lines else "No description"
)
print(f"- {tool.name}: {description}")
if not call_tool:
@@ -12,8 +12,14 @@ from agentic_security.llm_providers.base import (
)
def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop",
prompt_tokens=10, completion_tokens=5, total_tokens=15):
def _mock_response(
content="Hello",
model="openai/gpt-4o-mini",
finish_reason="stop",
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
):
resp = MagicMock()
resp.choices = [MagicMock()]
resp.choices[0].message.content = content
@@ -131,7 +137,10 @@ class TestLiteLLMProviderSync:
def test_sync_generate(self, provider):
resp = _mock_response(content="Sync response")
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.completion",
return_value=resp,
) as mock_comp:
result = provider.sync_generate("Hello")
assert result.content == snapshot("Sync response")
call_kwargs = mock_comp.call_args.kwargs
@@ -140,13 +149,19 @@ class TestLiteLLMProviderSync:
def test_sync_chat(self, provider):
resp = _mock_response(content="Chat response")
messages = [LLMMessage(role="user", content="Hi")]
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp):
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.completion",
return_value=resp,
):
result = provider.sync_chat(messages)
assert result.content == snapshot("Chat response")
def test_sync_generate_with_system_prompt(self, provider):
resp = _mock_response(content="With system")
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.completion",
return_value=resp,
) as mock_comp:
result = provider.sync_generate("Hello", system_prompt="Be brief")
assert result.content == snapshot("With system")
messages = mock_comp.call_args.kwargs["messages"]
@@ -162,7 +177,11 @@ class TestLiteLLMProviderAsync:
@pytest.mark.asyncio
async def test_generate(self, provider):
resp = _mock_response(content="Async response")
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp):
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
new_callable=AsyncMock,
return_value=resp,
):
result = await provider.generate("Hello")
assert result.content == snapshot("Async response")
@@ -170,7 +189,11 @@ class TestLiteLLMProviderAsync:
async def test_chat(self, provider):
resp = _mock_response(content="Async chat")
messages = [LLMMessage(role="user", content="Hi")]
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp) as mock_acomp:
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
new_callable=AsyncMock,
return_value=resp,
) as mock_acomp:
result = await provider.chat(messages)
assert result.content == snapshot("Async chat")
call_kwargs = mock_acomp.call_args.kwargs
@@ -195,22 +218,31 @@ class TestLiteLLMProviderErrors:
provider._handle_error(Exception("something went wrong"))
def test_sync_chat_auth_error_raises_provider_error(self, provider):
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", side_effect=Exception("AuthenticationError: Invalid API key")):
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.completion",
side_effect=Exception("AuthenticationError: Invalid API key"),
):
with pytest.raises(LLMProviderError, match="Invalid API key"):
provider.sync_generate("test")
@pytest.mark.asyncio
async def test_async_chat_timeout_raises_provider_error(self, provider):
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
side_effect=Exception("Timeout: Request timed out")):
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
new_callable=AsyncMock,
side_effect=Exception("Timeout: Request timed out"),
):
with pytest.raises(LLMProviderError, match="timed out"):
await provider.generate("test")
@pytest.mark.asyncio
async def test_async_chat_model_not_found_raises_provider_error(self, provider):
provider = LiteLLMProvider(model="bad/nonexistent-model")
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
side_effect=Exception("NotFoundError: Model not found")):
with patch(
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
new_callable=AsyncMock,
side_effect=Exception("NotFoundError: Model not found"),
):
with pytest.raises(LLMProviderError, match="not found"):
await provider.generate("test")
@@ -218,16 +250,19 @@ class TestLiteLLMProviderErrors:
class TestLiteLLMProviderFactory:
def test_factory_creates_litellm_provider(self):
from agentic_security.llm_providers.factory import create_provider
provider = create_provider("litellm")
assert isinstance(provider, LiteLLMProvider)
assert provider.model == snapshot("openai/gpt-4o-mini")
def test_factory_creates_with_custom_model(self):
from agentic_security.llm_providers.factory import create_provider
provider = create_provider("litellm", model="groq/llama-3.3-70b-versatile")
assert provider.model == snapshot("groq/llama-3.3-70b-versatile")
def test_factory_lists_litellm(self):
from agentic_security.llm_providers.factory import list_providers
providers = list_providers()
assert "litellm" in providers