From dbeb4bc660a910aa8b8028fac86bea7b51949365 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 29 Apr 2025 12:17:01 +0530 Subject: [PATCH] Add blocking guardrails based test for MCP gateway. --- tests/integration/mcp/test_mcp_stdio.py | 87 ++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 2 deletions(-) diff --git a/tests/integration/mcp/test_mcp_stdio.py b/tests/integration/mcp/test_mcp_stdio.py index de4925e..b1bd118 100644 --- a/tests/integration/mcp/test_mcp_stdio.py +++ b/tests/integration/mcp/test_mcp_stdio.py @@ -6,6 +6,7 @@ import uuid import pytest import requests +from mcp.shared.exceptions import McpError from utils import create_dataset, add_guardrail_to_dataset from resources.mcp.client.main import run as mcp_client_run @@ -19,7 +20,7 @@ async def test_mcp_stdio_with_gateway( """Test MCP gateway via stdio and verify trace is pushed to explorer""" project_name = "test-mcp-" + str(uuid.uuid4()) - # Run the MCP client and get the result + # Run the MCP client and make the tool call. result = await mcp_client_run( invariant_gateway_package_whl_file, project_name, @@ -96,7 +97,7 @@ async def test_mcp_stdio_with_gateway_and_logging_guardrails( invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), ) - # Run the MCP client and get the result + # Run the MCP client and make the tool call. result = await mcp_client_run( invariant_gateway_package_whl_file, project_name, @@ -170,3 +171,85 @@ async def test_mcp_stdio_with_gateway_and_logging_guardrails( ), "Missing 'get_last_message_from_user is called' annotation" assert food_annotation["extra_metadata"]["source"] == "guardrails-error" assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error" + + +@pytest.mark.asyncio +async def test_mcp_stdio_with_gateway_and_blocking_guardrails( + explorer_api_url, invariant_gateway_package_whl_file +): + """Test MCP gateway via stdio and verify that blocking guardrails work""" + project_name = "test-mcp-" + str(uuid.uuid4()) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=project_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user', + action="block", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + # Run the MCP client and make the tool call. + try: + _ = await mcp_client_run( + invariant_gateway_package_whl_file, + project_name, + server_script_path="resources/mcp/messenger_server/main.py", + push_to_explorer=True, + tool_name="get_last_message_from_user", + tool_args={"username": "Alice"}, + ) + # The tool call should be blocked by the guardrail + # and an error should be raised. + except McpError as e: + assert ( + "[Invariant Guardrails] The MCP tool call was blocked for security reasons" + in e.error.message + ) + assert "get_last_message_from_user is called" in e.error.message + assert e.error.code == -32600 + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + metadata = trace["extra_metadata"] + assert ( + metadata["source"] == "mcp" + and metadata["mcp_client"] == "mcp" + and metadata["mcp_server"] == "messenger_server" + ) + assert trace["messages"][0]["role"] == "assistant" + assert trace["messages"][0]["tool_calls"][0]["function"] == { + "name": "get_last_message_from_user", + "arguments": {"username": "Alice"}, + } + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "get_last_message_from_user is called" + and annotations[0]["address"] == "messages.0.tool_calls.0" + ) + assert annotations[0]["extra_metadata"]["source"] == "guardrails-error"