diff --git a/agentic_security/mcp/client.py b/agentic_security/mcp/client.py index 79c3758..d7c9b2e 100644 --- a/agentic_security/mcp/client.py +++ b/agentic_security/mcp/client.py @@ -11,13 +11,13 @@ server_params = StdioServerParameters( ) -async def run(): +async def run()-> None: async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: - # Initialize the connection + # Initialize the connection --> connection doesnt work await session.initialize() - # List available prompts, resources, and tools + # List available prompts, resources, and tools --> no avalialbe tools prompts = await session.list_prompts() print(f"Available prompts: {prompts}") @@ -27,7 +27,7 @@ async def run(): tools = await session.list_tools() print(f"Available tools: {tools}") - # Call the echo tool + # Call the echo tool --> echo tool iisue echo_result = await session.call_tool( "echo_tool", arguments={"message": "Hello from client!"} ) diff --git a/agentic_security/mcp/main.py b/agentic_security/mcp/main.py index 0e754dd..64a7268 100644 --- a/agentic_security/mcp/main.py +++ b/agentic_security/mcp/main.py @@ -14,7 +14,15 @@ AGENTIC_SECURITY = "http://0.0.0.0:8718" @mcp.tool() async def verify_llm(spec: str) -> dict: - """Verify an LLM model specification using the FastAPI server.""" + """ + Verify an LLM model specification using the FastAPI server + + Returns: + dict: containing the verification result form the FastAPI server + + Args: spect(str): The specification of the LLM model to verify. + + """ url = f"{AGENTIC_SECURITY}/verify" async with httpx.AsyncClient() as client: response = await client.post(url, json={"spec": spec}) @@ -28,7 +36,18 @@ async def start_scan( optimize: bool = False, enableMultiStepAttack: bool = False, ) -> dict: - """Start an LLM security scan via the FastAPI server.""" + """ + Start an LLM security scan via the FastAPI server. + Returns: + dict: The scan initiation result from the FastAPI server. + + Args: + llmSpec (str): The specification of the LLM model. + maxBudget (int): The maximum budget for the scan. + optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False. + enableMultiStepAttack (bool, optional): Whether to enable multi-step attack + + """ url = f"{AGENTIC_SECURITY}/scan" payload = { "llmSpec": llmSpec, @@ -46,7 +65,11 @@ async def start_scan( @mcp.tool() async def stop_scan() -> dict: - """Stop an ongoing scan via the FastAPI server.""" + """Stop an ongoing scan via the FastAPI server. + + Returns: + dict: The confirmation from the FastAPI server that the scan has been stopped. + """ url = f"{AGENTIC_SECURITY}/stop" async with httpx.AsyncClient() as client: response = await client.post(url) @@ -55,7 +78,12 @@ async def stop_scan() -> dict: @mcp.tool() async def get_data_config() -> list: - """Retrieve data configuration from the FastAPI server.""" + """ + Retrieve data configuration from the FastAPI server. + + Returns: + list: The response from the FastAPI server, confirming the scan has been stopped. + """ url = f"{AGENTIC_SECURITY}/v1/data-config" async with httpx.AsyncClient() as client: response = await client.get(url) @@ -64,7 +92,12 @@ async def get_data_config() -> list: @mcp.tool() async def get_spec_templates() -> list: - """Retrieve data configuration from the FastAPI server.""" + """ + Retrieve data configuration from the FastAPI server. + + Returns: + list: The LLM specification templates from the FastAPI server. + """ url = f"{AGENTIC_SECURITY}/v1/llm-specs" async with httpx.AsyncClient() as client: response = await client.get(url) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..fdf7ec0 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,76 @@ +import pytest +import asyncio +from unittest.mock import AsyncMock, patch +from agentic_security.mcp import ClientSession +from ..agentic_security.mcp.client import run + +# Fixtures +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + +@pytest.fixture +async def mock_session(): + with patch('mcp.client.stdio.stdio_client') as mock_client: + + mock_read = AsyncMock() + mock_write = AsyncMock() + + # Configures mock client such that mock responses are returned + mock_client.return_value.__aenter__.return_value = (mock_read, mock_write) + + # Creates a mock session + mock_session = AsyncMock(spec=ClientSession) + + # Expected responses + mock_session.initialize = AsyncMock() + mock_session.list_prompts = AsyncMock(return_value=["test_prompt"]) + mock_session.list_resources = AsyncMock(return_value=["test_resource"]) + mock_session.list_tools = AsyncMock(return_value=["echo_tool"]) + mock_session.call_tool = AsyncMock(return_value="Hello from client!") + + with patch('mcp.ClientSession', return_value=mock_session): + yield mock_session + +# Tests + +@pytest.mark.asyncio +async def test_initialization(mock_session): + """Test initialization success and failure cases""" + # Test initialization success case + await run() + mock_session.initialize.assert_called_once() + + # Resetting the mock to test for failure case + mock_session.initialize.reset_mock() + mock_session.initialize.side_effect = ConnectionError("Failed to connect") + + # Test connection error + with pytest.raises(ConnectionError): + await run() + +@pytest.mark.asyncio +async def test_list_resources(mock_session): + """Test listing available resources""" + await run() + mock_session.list_resources.assert_called_once() + assert await mock_session.list_resources() == ["test_resource"] + +@pytest.mark.asyncio +async def test_list_tools(mock_session): + """Test listing available tools""" + await run() + mock_session.list_tools.assert_called_once() + assert await mock_session.list_tools() == ["echo_tool"] + +@pytest.mark.asyncio +async def test_echo_tool(mock_session): + """Test the echo tool functionality""" + await run() + mock_session.call_tool.assert_called_once_with( + "echo_tool", + arguments={"message": "Hello from client!"} + ) \ No newline at end of file