mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-26 00:17:47 +02:00
Add blocking guardrails based test for MCP gateway.
This commit is contained in:
@@ -6,6 +6,7 @@ import uuid
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from utils import create_dataset, add_guardrail_to_dataset
|
||||
|
||||
from resources.mcp.client.main import run as mcp_client_run
|
||||
@@ -19,7 +20,7 @@ async def test_mcp_stdio_with_gateway(
|
||||
"""Test MCP gateway via stdio and verify trace is pushed to explorer"""
|
||||
project_name = "test-mcp-" + str(uuid.uuid4())
|
||||
|
||||
# Run the MCP client and get the result
|
||||
# Run the MCP client and make the tool call.
|
||||
result = await mcp_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -96,7 +97,7 @@ async def test_mcp_stdio_with_gateway_and_logging_guardrails(
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
# Run the MCP client and get the result
|
||||
# Run the MCP client and make the tool call.
|
||||
result = await mcp_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -170,3 +171,85 @@ async def test_mcp_stdio_with_gateway_and_logging_guardrails(
|
||||
), "Missing 'get_last_message_from_user is called' annotation"
|
||||
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_stdio_with_gateway_and_blocking_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file
|
||||
):
|
||||
"""Test MCP gateway via stdio and verify that blocking guardrails work"""
|
||||
project_name = "test-mcp-" + str(uuid.uuid4())
|
||||
|
||||
dataset_creation_response = await create_dataset(
|
||||
explorer_api_url,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
dataset_name=project_name,
|
||||
)
|
||||
dataset_id = dataset_creation_response["id"]
|
||||
_ = await add_guardrail_to_dataset(
|
||||
explorer_api_url,
|
||||
dataset_id=dataset_id,
|
||||
policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user',
|
||||
action="block",
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
# Run the MCP client and make the tool call.
|
||||
try:
|
||||
_ = await mcp_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
server_script_path="resources/mcp/messenger_server/main.py",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
# The tool call should be blocked by the guardrail
|
||||
# and an error should be raised.
|
||||
except McpError as e:
|
||||
assert (
|
||||
"[Invariant Guardrails] The MCP tool call was blocked for security reasons"
|
||||
in e.error.message
|
||||
)
|
||||
assert "get_last_message_from_user is called" in e.error.message
|
||||
assert e.error.code == -32600
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
metadata = trace["extra_metadata"]
|
||||
assert (
|
||||
metadata["source"] == "mcp"
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
assert trace["messages"][0]["role"] == "assistant"
|
||||
assert trace["messages"][0]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
"arguments": {"username": "Alice"},
|
||||
}
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"] == "get_last_message_from_user is called"
|
||||
and annotations[0]["address"] == "messages.0.tool_calls.0"
|
||||
)
|
||||
assert annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
|
||||
Reference in New Issue
Block a user