Files
invariant-gateway/tests/integration/mcp/test_mcp.py
T
Luca Beurer-Kellner e18c6b5bdb Add an option to add extra metadata that is pushed and passed to Guardrails during an MCP session (#47)
* use select() before readline

* support for setting static metadata for MCP sessions

* nest extra mcp metadata in metadata object

* unify session metadata

* extra metadata tests

* use empty object as parameters, if None

* list_tools as tool call

* offset indices in tests

* test: adjust addresses

* mcp: make error reporting configurable

* line logging

* log version

* verbose logging + loud exception failure

* add server and client name to policy get

* append trace even if not pushing

* port tools/list message support to SSE

* use python -m build

* adjust guardrail failure address

* support for blocking tools/list in SSE

* use error-based failure response format by default

* tools/list test

* don't list_tools in stdio connect

* flaky test: handle second possible result in anthropic streaming case

---------

Co-authored-by: knielsen404 <kristian@invariantlabs.ai>
2025-05-19 13:44:37 +02:00

478 lines
18 KiB
Python

"""Test MCP gateway via SSE and stdio transports."""
import os
import uuid
from resources.mcp.sse.client.main import run as mcp_sse_client_run
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
from utils import create_dataset, add_guardrail_to_dataset
import pytest
import requests
from mcp.shared.exceptions import McpError
# Host name and port used to build the upstream SSE messenger server URL
# (``http://<host>:<port>``) that the SSE client is pointed at through the gateway.
MCP_SSE_SERVER_HOST, MCP_SSE_SERVER_PORT = "mcp-messenger-sse-server", 8123
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
    "push_to_explorer, transport",
    [
        (False, "stdio"),
        (False, "sse"),
        (True, "stdio"),
        (True, "sse"),
    ],
)
async def test_mcp_with_gateway(
    explorer_api_url,
    invariant_gateway_package_whl_file,
    gateway_url,
    push_to_explorer,
    transport,
):
    """Test a plain MCP tool call through the gateway over SSE or stdio.

    Calls ``get_last_message_from_user`` for ``Alice`` and checks the returned
    text. When ``push_to_explorer`` is true, also checks that exactly one trace
    was stored for the project and that it records the MCP metadata, the tool
    call (message 2) and the tool output (message 3).
    """
    project_name = "test-mcp-" + str(uuid.uuid4())

    # Run the MCP client and make the tool call.
    if transport == "sse":
        result = await mcp_sse_client_run(
            gateway_url + "/api/v1/gateway/mcp/sse",
            f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
            project_name,
            push_to_explorer=push_to_explorer,
            tool_name="get_last_message_from_user",
            tool_args={"username": "Alice"},
        )
    else:
        # NOTE(review): extra session metadata is only passed on the stdio
        # transport and is not asserted on the pushed trace below; confirm
        # where the gateway stores it and add an assertion for it.
        result = await mcp_stdio_client_run(
            invariant_gateway_package_whl_file,
            project_name,
            server_script_path="resources/mcp/stdio/messenger_server/main.py",
            push_to_explorer=push_to_explorer,
            tool_name="get_last_message_from_user",
            tool_args={"username": "Alice"},
            metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
        )

    assert result.isError is False
    assert (
        result.content[0].type == "text"
        and result.content[0].text == "What is your favorite food?\n"
    )

    if not push_to_explorer:
        # The explorer checks below only apply when the trace is pushed.
        return

    # Fetch the trace ids for the dataset. raise_for_status() reports a
    # non-2xx explorer response directly (with its URL) instead of letting it
    # fail later on the decoded body.
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
        timeout=5,
    )
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace itself.
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    metadata = trace["extra_metadata"]
    assert (
        metadata["source"] == "mcp"
        and metadata["mcp_client"] == "mcp"
        and metadata["mcp_server"] == "messenger_server"
    )
    assert trace["messages"][2]["role"] == "assistant"
    assert trace["messages"][2]["tool_calls"][0]["function"] == {
        "name": "get_last_message_from_user",
        "arguments": {"username": "Alice"},
    }
    assert trace["messages"][3]["role"] == "tool"
    assert trace["messages"][3]["content"] == [
        {"type": "text", "text": "What is your favorite food?\n"}
    ]
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_with_gateway_and_logging_guardrails(
    explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
    """Test that ``log``-action guardrails annotate the trace without blocking.

    Two logging guardrails are attached to a fresh dataset: one matching
    "food" in a tool output and one matching the ``get_last_message_from_user``
    tool call. The tool call must still succeed, and the pushed trace must
    carry exactly one annotation for each guardrail at the expected address.
    """
    # Read the key once; a missing variable fails with a KeyError naming it,
    # instead of a TypeError from concatenating None.
    invariant_authorization = "Bearer " + os.environ["INVARIANT_API_KEY"]
    project_name = "test-mcp-" + str(uuid.uuid4())

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=project_name,
    )
    dataset_id = dataset_creation_response["id"]
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "food in ToolOutput" if:\n (tool_output: ToolOutput)\n (chunk: str) in text(tool_output.content)\n "food" in chunk',
        action="log",
        invariant_authorization=invariant_authorization,
    )
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user',
        action="log",
        invariant_authorization=invariant_authorization,
    )

    # Run the MCP client and make the tool call.
    if transport == "sse":
        result = await mcp_sse_client_run(
            gateway_url + "/api/v1/gateway/mcp/sse",
            f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
            project_name,
            push_to_explorer=True,
            tool_name="get_last_message_from_user",
            tool_args={"username": "Alice"},
        )
    else:
        result = await mcp_stdio_client_run(
            invariant_gateway_package_whl_file,
            project_name,
            server_script_path="resources/mcp/stdio/messenger_server/main.py",
            push_to_explorer=True,
            tool_name="get_last_message_from_user",
            tool_args={"username": "Alice"},
        )

    # Logging guardrails must not interfere with the tool result.
    assert result.isError is False
    assert (
        result.content[0].type == "text"
        and result.content[0].text == "What is your favorite food?\n"
    )

    # Fetch the trace ids for the dataset; raise_for_status() reports a
    # non-2xx explorer response directly.
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
        timeout=5,
    )
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace itself.
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    metadata = trace["extra_metadata"]
    assert (
        metadata["source"] == "mcp"
        and metadata["mcp_client"] == "mcp"
        and metadata["mcp_server"] == "messenger_server"
    )
    assert trace["messages"][2]["role"] == "assistant"
    assert trace["messages"][2]["tool_calls"][0]["function"] == {
        "name": "get_last_message_from_user",
        "arguments": {"username": "Alice"},
    }
    assert trace["messages"][3]["role"] == "tool"
    assert trace["messages"][3]["content"] == [
        {"type": "text", "text": "What is your favorite food?\n"}
    ]

    # Validate the annotations: one per guardrail, matched by message and
    # address ("food" spans characters 22-26 of the tool output text).
    annotations = trace["annotations"]
    assert len(annotations) == 2
    food_annotation = None
    tool_call_annotation = None
    for annotation in annotations:
        if (
            annotation["content"] == "food in ToolOutput"
            and annotation["address"] == "messages.3.content.0.text:22-26"
        ):
            food_annotation = annotation
        elif (
            annotation["content"] == "get_last_message_from_user is called"
            and annotation["address"] == "messages.2.tool_calls.0"
        ):
            tool_call_annotation = annotation
    assert food_annotation is not None, "Missing 'food in ToolOutput' annotation"
    assert (
        tool_call_annotation is not None
    ), "Missing 'get_last_message_from_user is called' annotation"
    assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
    assert food_annotation["extra_metadata"]["guardrail"]["action"] == "log"
    assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
    assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_with_gateway_and_blocking_guardrails(
    explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
    """Test that a ``block``-action guardrail on a tool call stops the call.

    A blocking guardrail matching the ``get_last_message_from_user`` tool call
    is attached to a fresh dataset. The client must receive an ``McpError``
    carrying the guardrail message, and the pushed trace must contain the
    tool call with a single ``block`` annotation on it.
    """
    # Read the key once; a missing variable fails with a KeyError naming it,
    # instead of a TypeError from concatenating None.
    invariant_authorization = "Bearer " + os.environ["INVARIANT_API_KEY"]
    project_name = "test-mcp-" + str(uuid.uuid4())

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=project_name,
    )
    dataset_id = dataset_creation_response["id"]
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user',
        action="block",
        invariant_authorization=invariant_authorization,
    )

    # Run the MCP client and make the tool call. The guardrail should block it
    # and the gateway should surface that to the client as an McpError.
    with pytest.raises(McpError) as exc_info:
        if transport == "sse":
            await mcp_sse_client_run(
                gateway_url + "/api/v1/gateway/mcp/sse",
                f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
                project_name,
                push_to_explorer=True,
                tool_name="get_last_message_from_user",
                tool_args={"username": "Alice"},
            )
        else:
            await mcp_stdio_client_run(
                invariant_gateway_package_whl_file,
                project_name,
                server_script_path="resources/mcp/stdio/messenger_server/main.py",
                push_to_explorer=True,
                tool_name="get_last_message_from_user",
                tool_args={"username": "Alice"},
            )

    error = exc_info.value.error
    assert (
        "[Invariant Guardrails] The MCP tool call was blocked for security reasons"
        in error.message
    )
    assert "get_last_message_from_user is called" in error.message
    # -32600 is the JSON-RPC 2.0 "Invalid Request" error code.
    assert error.code == -32600

    # Fetch the trace ids for the dataset; raise_for_status() reports a
    # non-2xx explorer response directly.
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
        timeout=5,
    )
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace itself.
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    metadata = trace["extra_metadata"]
    assert (
        metadata["source"] == "mcp"
        and metadata["mcp_client"] == "mcp"
        and metadata["mcp_server"] == "messenger_server"
    )
    # The blocked tool call is still recorded; no tool output is checked
    # because the call never reached the server.
    assert trace["messages"][2]["role"] == "assistant"
    assert trace["messages"][2]["tool_calls"][0]["function"] == {
        "name": "get_last_message_from_user",
        "arguments": {"username": "Alice"},
    }

    # Validate the single blocking annotation on the tool call.
    annotations = trace["annotations"]
    assert len(annotations) == 1
    assert (
        annotations[0]["content"] == "get_last_message_from_user is called"
        and annotations[0]["address"] == "messages.2.tool_calls.0"
    )
    assert annotations[0]["extra_metadata"]["source"] == "guardrails-error"
    assert annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_sse_with_gateway_hybrid_guardrails(
    explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
    """Test logging and blocking guardrails together on one tool call.

    Despite the ``sse`` in its name, this test is parametrized over both the
    stdio and SSE transports. A ``log`` guardrail matches the tool call and a
    ``block`` guardrail matches "food" in the tool output, so the client must
    receive an ``McpError`` for the output, while the trace records both the
    call and the output with one annotation per guardrail.
    """
    # Read the key once; a missing variable fails with a KeyError naming it,
    # instead of a TypeError from concatenating None.
    invariant_authorization = "Bearer " + os.environ["INVARIANT_API_KEY"]
    project_name = "test-mcp-" + str(uuid.uuid4())

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=project_name,
    )
    dataset_id = dataset_creation_response["id"]
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user',
        action="log",
        invariant_authorization=invariant_authorization,
    )
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "food in ToolOutput" if:\n (tool_output: ToolOutput)\n (chunk: str) in text(tool_output.content)\n "food" in chunk',
        action="block",
        invariant_authorization=invariant_authorization,
    )

    # Run the MCP client and make the tool call. The tool output should be
    # blocked by the guardrail and surfaced to the client as an McpError.
    with pytest.raises(McpError) as exc_info:
        if transport == "sse":
            await mcp_sse_client_run(
                gateway_url + "/api/v1/gateway/mcp/sse",
                f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
                project_name,
                push_to_explorer=True,
                tool_name="get_last_message_from_user",
                tool_args={"username": "Alice"},
            )
        else:
            await mcp_stdio_client_run(
                invariant_gateway_package_whl_file,
                project_name,
                server_script_path="resources/mcp/stdio/messenger_server/main.py",
                push_to_explorer=True,
                tool_name="get_last_message_from_user",
                tool_args={"username": "Alice"},
            )

    error = exc_info.value.error
    assert (
        "[Invariant Guardrails] The MCP tool call was blocked for security reasons"
        in error.message
    )
    assert "food in ToolOutput" in error.message
    # -32600 is the JSON-RPC 2.0 "Invalid Request" error code.
    assert error.code == -32600

    # Fetch the trace ids for the dataset; raise_for_status() reports a
    # non-2xx explorer response directly.
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
        timeout=5,
    )
    traces_response.raise_for_status()
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace itself.
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    metadata = trace["extra_metadata"]
    assert (
        metadata["source"] == "mcp"
        and metadata["mcp_client"] == "mcp"
        and metadata["mcp_server"] == "messenger_server"
    )
    assert trace["messages"][2]["role"] == "assistant"
    assert trace["messages"][2]["tool_calls"][0]["function"] == {
        "name": "get_last_message_from_user",
        "arguments": {"username": "Alice"},
    }
    # The blocked tool output is still recorded in the trace.
    assert trace["messages"][3]["role"] == "tool"
    assert trace["messages"][3]["content"] == [
        {"type": "text", "text": "What is your favorite food?\n"}
    ]

    # Validate the annotations: one per guardrail, matched by message and
    # address ("food" spans characters 22-26 of the tool output text).
    annotations = trace["annotations"]
    assert len(annotations) == 2
    food_annotation = None
    tool_call_annotation = None
    for annotation in annotations:
        if (
            annotation["content"] == "food in ToolOutput"
            and annotation["address"] == "messages.3.content.0.text:22-26"
        ):
            food_annotation = annotation
        elif (
            annotation["content"] == "get_last_message_from_user is called"
            and annotation["address"] == "messages.2.tool_calls.0"
        ):
            tool_call_annotation = annotation
    assert food_annotation is not None, "Missing 'food in ToolOutput' annotation"
    assert (
        tool_call_annotation is not None
    ), "Missing 'get_last_message_from_user is called' annotation"
    assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
    assert food_annotation["extra_metadata"]["guardrail"]["action"] == "block"
    assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
    assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_tool_list_blocking(
    explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
    """
    Tests that blocking guardrails work for the tools/list call.

    For those, the expected behavior is that the returned tools are all renamed
    to ``blocked_...`` and include an informative block notice instead of the
    original tool description.
    """
    # Read the key once; a missing variable fails with a KeyError naming it,
    # instead of a TypeError from concatenating None.
    invariant_authorization = "Bearer " + os.environ["INVARIANT_API_KEY"]
    project_name = "test-mcp-" + str(uuid.uuid4())

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=project_name,
    )
    dataset_id = dataset_creation_response["id"]
    # NOTE(review): the raised message text is reused from the tool-call
    # guardrail even though this policy matches the tools/list output; it is
    # not asserted on below, so consider renaming it for clarity.
    await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "get_last_message_from_user is called" if:\n (tool_output: ToolOutput)\n tool_call(tool_output).function.name == "tools/list"',
        action="block",
        invariant_authorization=invariant_authorization,
    )

    # Run the MCP client and make the tools/list call.
    if transport == "sse":
        tools_result = await mcp_sse_client_run(
            gateway_url + "/api/v1/gateway/mcp/sse",
            f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
            project_name,
            push_to_explorer=True,
            tool_name="tools/list",
            tool_args={},
        )
    else:
        tools_result = await mcp_stdio_client_run(
            invariant_gateway_package_whl_file,
            project_name,
            server_script_path="resources/mcp/stdio/messenger_server/main.py",
            push_to_explorer=True,
            tool_name="tools/list",
            tool_args={},
        )

    rendered = str(tools_result)
    assert "blocked_get_last_message_from_user" in rendered, (
        "Expected the tool names to be renamed and blocked because of the "
        "blocking guardrail on the tools/list call. Instead got: " + rendered
    )