mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(pc):
This commit is contained in:
Regular → Executable
+3
-2
@@ -18,7 +18,6 @@ from typing import Any
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
NO_ARGUMENT_TOOLS = {"get_data_config", "get_spec_templates", "stop_scan"}
|
||||
|
||||
|
||||
@@ -56,7 +55,9 @@ async def run_client(agentic_security_url: str | None, call_tool: str | None) ->
|
||||
print("Available Agentic Security MCP tools:")
|
||||
for tool in tools.tools:
|
||||
description_lines = (tool.description or "").strip().splitlines()
|
||||
description = description_lines[0] if description_lines else "No description"
|
||||
description = (
|
||||
description_lines[0] if description_lines else "No description"
|
||||
)
|
||||
print(f"- {tool.name}: {description}")
|
||||
|
||||
if not call_tool:
|
||||
|
||||
@@ -12,8 +12,14 @@ from agentic_security.llm_providers.base import (
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", model="openai/gpt-4o-mini", finish_reason="stop",
|
||||
prompt_tokens=10, completion_tokens=5, total_tokens=15):
|
||||
def _mock_response(
|
||||
content="Hello",
|
||||
model="openai/gpt-4o-mini",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
):
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = content
|
||||
@@ -131,7 +137,10 @@ class TestLiteLLMProviderSync:
|
||||
|
||||
def test_sync_generate(self, provider):
|
||||
resp = _mock_response(content="Sync response")
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.completion",
|
||||
return_value=resp,
|
||||
) as mock_comp:
|
||||
result = provider.sync_generate("Hello")
|
||||
assert result.content == snapshot("Sync response")
|
||||
call_kwargs = mock_comp.call_args.kwargs
|
||||
@@ -140,13 +149,19 @@ class TestLiteLLMProviderSync:
|
||||
def test_sync_chat(self, provider):
|
||||
resp = _mock_response(content="Chat response")
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp):
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.completion",
|
||||
return_value=resp,
|
||||
):
|
||||
result = provider.sync_chat(messages)
|
||||
assert result.content == snapshot("Chat response")
|
||||
|
||||
def test_sync_generate_with_system_prompt(self, provider):
|
||||
resp = _mock_response(content="With system")
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", return_value=resp) as mock_comp:
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.completion",
|
||||
return_value=resp,
|
||||
) as mock_comp:
|
||||
result = provider.sync_generate("Hello", system_prompt="Be brief")
|
||||
assert result.content == snapshot("With system")
|
||||
messages = mock_comp.call_args.kwargs["messages"]
|
||||
@@ -162,7 +177,11 @@ class TestLiteLLMProviderAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate(self, provider):
|
||||
resp = _mock_response(content="Async response")
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp):
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
|
||||
new_callable=AsyncMock,
|
||||
return_value=resp,
|
||||
):
|
||||
result = await provider.generate("Hello")
|
||||
assert result.content == snapshot("Async response")
|
||||
|
||||
@@ -170,7 +189,11 @@ class TestLiteLLMProviderAsync:
|
||||
async def test_chat(self, provider):
|
||||
resp = _mock_response(content="Async chat")
|
||||
messages = [LLMMessage(role="user", content="Hi")]
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock, return_value=resp) as mock_acomp:
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
|
||||
new_callable=AsyncMock,
|
||||
return_value=resp,
|
||||
) as mock_acomp:
|
||||
result = await provider.chat(messages)
|
||||
assert result.content == snapshot("Async chat")
|
||||
call_kwargs = mock_acomp.call_args.kwargs
|
||||
@@ -195,22 +218,31 @@ class TestLiteLLMProviderErrors:
|
||||
provider._handle_error(Exception("something went wrong"))
|
||||
|
||||
def test_sync_chat_auth_error_raises_provider_error(self, provider):
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.completion", side_effect=Exception("AuthenticationError: Invalid API key")):
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.completion",
|
||||
side_effect=Exception("AuthenticationError: Invalid API key"),
|
||||
):
|
||||
with pytest.raises(LLMProviderError, match="Invalid API key"):
|
||||
provider.sync_generate("test")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_timeout_raises_provider_error(self, provider):
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
|
||||
side_effect=Exception("Timeout: Request timed out")):
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Timeout: Request timed out"),
|
||||
):
|
||||
with pytest.raises(LLMProviderError, match="timed out"):
|
||||
await provider.generate("test")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_model_not_found_raises_provider_error(self, provider):
|
||||
provider = LiteLLMProvider(model="bad/nonexistent-model")
|
||||
with patch("agentic_security.llm_providers.litellm_provider.litellm.acompletion", new_callable=AsyncMock,
|
||||
side_effect=Exception("NotFoundError: Model not found")):
|
||||
with patch(
|
||||
"agentic_security.llm_providers.litellm_provider.litellm.acompletion",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("NotFoundError: Model not found"),
|
||||
):
|
||||
with pytest.raises(LLMProviderError, match="not found"):
|
||||
await provider.generate("test")
|
||||
|
||||
@@ -218,16 +250,19 @@ class TestLiteLLMProviderErrors:
|
||||
class TestLiteLLMProviderFactory:
|
||||
def test_factory_creates_litellm_provider(self):
|
||||
from agentic_security.llm_providers.factory import create_provider
|
||||
|
||||
provider = create_provider("litellm")
|
||||
assert isinstance(provider, LiteLLMProvider)
|
||||
assert provider.model == snapshot("openai/gpt-4o-mini")
|
||||
|
||||
def test_factory_creates_with_custom_model(self):
|
||||
from agentic_security.llm_providers.factory import create_provider
|
||||
|
||||
provider = create_provider("litellm", model="groq/llama-3.3-70b-versatile")
|
||||
assert provider.model == snapshot("groq/llama-3.3-70b-versatile")
|
||||
|
||||
def test_factory_lists_litellm(self):
|
||||
from agentic_security.llm_providers.factory import list_providers
|
||||
|
||||
providers = list_providers()
|
||||
assert "litellm" in providers
|
||||
|
||||
Reference in New Issue
Block a user