From 0dc738a11ef2734da3cb28b675e3adc44b65874c Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Thu, 3 Apr 2025 20:43:53 +0300 Subject: [PATCH] fix(pc): --- agentic_security/mcp/client.py | 8 +++---- agentic_security/mcp/main.py | 16 ++++++------- tests/test_client.py | 44 ++++++++++++++++++++-------------- 3 files changed, 38 insertions(+), 30 deletions(-) diff --git a/agentic_security/mcp/client.py b/agentic_security/mcp/client.py index d7c9b2e..e328c6a 100644 --- a/agentic_security/mcp/client.py +++ b/agentic_security/mcp/client.py @@ -11,13 +11,13 @@ server_params = StdioServerParameters( ) -async def run()-> None: +async def run() -> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: - # Initialize the connection --> connection doesnt work + # Initialize the connection --> connection does not work await session.initialize() - # List available prompts, resources, and tools --> no avalialbe tools + # List available prompts, resources, and tools --> no avalialbe tools prompts = await session.list_prompts() print(f"Available prompts: {prompts}") @@ -27,7 +27,7 @@ async def run()-> None: tools = await session.list_tools() print(f"Available tools: {tools}") - # Call the echo tool --> echo tool iisue + # Call the echo tool --> echo tool iisue echo_result = await session.call_tool( "echo_tool", arguments={"message": "Hello from client!"} ) diff --git a/agentic_security/mcp/main.py b/agentic_security/mcp/main.py index 64a7268..9f4efed 100644 --- a/agentic_security/mcp/main.py +++ b/agentic_security/mcp/main.py @@ -17,11 +17,11 @@ async def verify_llm(spec: str) -> dict: """ Verify an LLM model specification using the FastAPI server - Returns: - dict: containing the verification result form the FastAPI server - + Returns: + dict: containing the verification result form the FastAPI server + Args: spect(str): The specification of the LLM model to verify. - + """ url = f"{AGENTIC_SECURITY}/verify" async with httpx.AsyncClient() as client: @@ -40,13 +40,13 @@ async def start_scan( Start an LLM security scan via the FastAPI server. Returns: dict: The scan initiation result from the FastAPI server. - - Args: + + Args: llmSpec (str): The specification of the LLM model. maxBudget (int): The maximum budget for the scan. optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False. - enableMultiStepAttack (bool, optional): Whether to enable multi-step attack - + enableMultiStepAttack (bool, optional): Whether to enable multi-step attack + """ url = f"{AGENTIC_SECURITY}/scan" payload = { diff --git a/tests/test_client.py b/tests/test_client.py index fdf7ec0..fabcc31 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,10 +1,14 @@ -import pytest import asyncio from unittest.mock import AsyncMock, patch + +import pytest + from agentic_security.mcp import ClientSession + from ..agentic_security.mcp.client import run -# Fixtures + +# Fixtures @pytest.fixture(scope="session") def event_loop(): """Create an instance of the default event loop for each test case.""" @@ -12,46 +16,49 @@ def event_loop(): yield loop loop.close() + @pytest.fixture async def mock_session(): - with patch('mcp.client.stdio.stdio_client') as mock_client: - + with patch("mcp.client.stdio.stdio_client") as mock_client: mock_read = AsyncMock() mock_write = AsyncMock() - - # Configures mock client such that mock responses are returned + + # Configures mock client such that mock responses are returned mock_client.return_value.__aenter__.return_value = (mock_read, mock_write) - + # Creates a mock session mock_session = AsyncMock(spec=ClientSession) - + # Expected responses mock_session.initialize = AsyncMock() mock_session.list_prompts = AsyncMock(return_value=["test_prompt"]) mock_session.list_resources = AsyncMock(return_value=["test_resource"]) mock_session.list_tools = AsyncMock(return_value=["echo_tool"]) mock_session.call_tool = AsyncMock(return_value="Hello from client!") - - with patch('mcp.ClientSession', return_value=mock_session): + + with patch("mcp.ClientSession", return_value=mock_session): yield mock_session -# Tests + +# Tests + @pytest.mark.asyncio async def test_initialization(mock_session): """Test initialization success and failure cases""" - # Test initialization success case + # Test initialization success case await run() mock_session.initialize.assert_called_once() - - # Resetting the mock to test for failure case + + # Resetting the mock to test for failure case mock_session.initialize.reset_mock() mock_session.initialize.side_effect = ConnectionError("Failed to connect") - + # Test connection error with pytest.raises(ConnectionError): await run() + @pytest.mark.asyncio async def test_list_resources(mock_session): """Test listing available resources""" @@ -59,6 +66,7 @@ async def test_list_resources(mock_session): mock_session.list_resources.assert_called_once() assert await mock_session.list_resources() == ["test_resource"] + @pytest.mark.asyncio async def test_list_tools(mock_session): """Test listing available tools""" @@ -66,11 +74,11 @@ async def test_list_tools(mock_session): mock_session.list_tools.assert_called_once() assert await mock_session.list_tools() == ["echo_tool"] + @pytest.mark.asyncio async def test_echo_tool(mock_session): """Test the echo tool functionality""" await run() mock_session.call_tool.assert_called_once_with( - "echo_tool", - arguments={"message": "Hello from client!"} - ) \ No newline at end of file + "echo_tool", arguments={"message": "Hello from client!"} + )