Add blocking and logging related tests for MCP streamable HTTP route.

This commit is contained in:
Hemang
2025-05-27 22:36:44 +02:00
committed by Hemang Sarkar
parent 115ae5f36b
commit ab3fb98b67
4 changed files with 177 additions and 82 deletions
+1
View File
@@ -89,6 +89,7 @@ async def hook_tool_call(
guardrails_result = await session.get_guardrails_check_result(
message, action=GuardrailAction.BLOCK
)
print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True)
# If the request is blocked, return a message indicating the block reason.
# If there are new errors, run append_and_push_trace in background.
# If there are no new errors, just return the original request.
+3 -2
View File
@@ -81,8 +81,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
# Intercept the request and check for guardrails.
if not is_initialization_request:
if result := await _intercept_request(session_id, request_body) and result: # noqa: F821 pylint: disable=used-before-assignment
return result
request_interception_result = await _intercept_request(session_id, request_body)
if request_interception_result:
return request_interception_result
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
+155 -13
View File
@@ -164,7 +164,17 @@ async def test_mcp_with_gateway(
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
@pytest.mark.parametrize(
"transport",
[
"stdio",
"sse",
"streamable-json-stateless",
"streamable-json-stateful",
"streamable-sse-stateless",
"streamable-sse-stateful",
],
)
async def test_mcp_with_gateway_and_logging_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
@@ -201,7 +211,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
tool_args={"username": "Alice"},
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
else:
elif transport == "stdio":
result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
@@ -210,6 +220,16 @@ async def test_mcp_with_gateway_and_logging_guardrails(
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
)
else:
result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
assert result.isError is False
assert (
@@ -232,6 +252,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
timeout=5,
)
trace = trace_response.json()
metadata = trace["extra_metadata"]
assert (
metadata["source"] == "mcp"
@@ -240,6 +261,19 @@ async def test_mcp_with_gateway_and_logging_guardrails(
)
assert "session_id" in metadata
assert "system_user" in metadata
if transport == "streamable-json-stateless":
assert metadata["server_response_type"] == "json"
assert metadata["is_stateless_http_server"] is True
elif transport == "streamable-json-stateful":
assert metadata["server_response_type"] == "json"
assert metadata["is_stateless_http_server"] is False
elif transport == "streamable-sse-stateless":
assert metadata["server_response_type"] == "sse"
assert metadata["is_stateless_http_server"] is True
elif transport == "streamable-sse-stateful":
assert metadata["server_response_type"] == "sse"
assert metadata["is_stateless_http_server"] is False
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
@@ -279,7 +313,17 @@ async def test_mcp_with_gateway_and_logging_guardrails(
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
@pytest.mark.parametrize(
"transport",
[
"stdio",
"sse",
"streamable-json-stateless",
"streamable-json-stateful",
"streamable-sse-stateless",
"streamable-sse-stateful",
],
)
async def test_mcp_with_gateway_and_blocking_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
@@ -312,7 +356,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
_get_mcp_sse_server_base_url(), project_name, True
),
)
else:
elif transport == "stdio":
_ = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
@@ -321,8 +365,9 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
)
# If we get here, the tool call was not blocked
pytest.fail("Expected McpError to be raised")
if not transport.startswith("streamable-"):
# If we get here, the tool call was not blocked
pytest.fail("Expected McpError to be raised")
# The tool call should be blocked by the guardrail
# and an error should be raised.
except McpError as e:
@@ -333,6 +378,25 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
assert "get_last_message_from_user is called" in e.error.message
assert e.error.code == -32600
if transport.startswith("streamable-"):
with pytest.raises(ExceptionGroup) as exc_group:
_ = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
# Extract the actual HTTPStatusError
http_errors = [
e
for e in exc_group.value.exceptions
if isinstance(e, httpx.HTTPStatusError)
]
assert http_errors[0].response.status_code == 400
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
@@ -375,7 +439,17 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
@pytest.mark.parametrize(
"transport",
[
"stdio",
"sse",
"streamable-json-stateless",
"streamable-json-stateful",
"streamable-sse-stateless",
"streamable-sse-stateful",
],
)
async def test_mcp_with_gateway_hybrid_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
@@ -416,7 +490,7 @@ async def test_mcp_with_gateway_hybrid_guardrails(
_get_mcp_sse_server_base_url(), project_name, True
),
)
else:
elif transport == "stdio":
_ = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
@@ -425,8 +499,9 @@ async def test_mcp_with_gateway_hybrid_guardrails(
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
)
# If we get here, the tool call was not blocked
pytest.fail("Expected McpError to be raised")
if not transport.startswith("streamable-"):
# If we get here, the tool call was not blocked
pytest.fail("Expected McpError to be raised")
# The tool call output should be blocked by the guardrail
# and an error should be raised.
except McpError as e:
@@ -437,6 +512,34 @@ async def test_mcp_with_gateway_hybrid_guardrails(
assert "food in ToolOutput" in e.error.message
assert e.error.code == -32600
if transport.startswith("streamable-"):
with pytest.raises(ExceptionGroup) as exc_group:
_ = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
if transport.startswith("streamable-json"):
# Extract the actual HTTPStatusError
http_errors = [
e
for e in exc_group.value.exceptions
if isinstance(e, httpx.HTTPStatusError)
]
assert http_errors[0].response.status_code == 400
else:
mcp_error = [e for e in exc_group.value.exceptions][0].exceptions[0]
assert (
"[Invariant Guardrails] The MCP tool call was blocked for security reasons"
in mcp_error.error.message
)
assert "food in ToolOutput" in mcp_error.error.message
assert -32600 == mcp_error.error.code
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
@@ -499,7 +602,17 @@ async def test_mcp_with_gateway_hybrid_guardrails(
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
@pytest.mark.parametrize(
"transport",
[
"stdio",
"sse",
"streamable-json-stateless",
"streamable-json-stateful",
"streamable-sse-stateless",
"streamable-sse-stateful",
],
)
async def test_mcp_tool_list_blocking(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
@@ -524,6 +637,26 @@ async def test_mcp_tool_list_blocking(
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
if transport.startswith("streamable-json"):
with pytest.raises(ExceptionGroup) as exc_group:
_ = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
# Extract the actual HTTPStatusError
http_errors = [
e
for e in exc_group.value.exceptions
if isinstance(e, httpx.HTTPStatusError)
]
assert http_errors[0].response.status_code == 400
return
# Run the MCP client and make the tools/list call.
if transport == "sse":
tools_result = await mcp_sse_client_run(
@@ -533,7 +666,7 @@ async def test_mcp_tool_list_blocking(
tool_args={},
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
)
else:
elif transport == "stdio":
tools_result = await mcp_stdio_client_run(
invariant_gateway_package_whl_file,
project_name,
@@ -542,7 +675,16 @@ async def test_mcp_tool_list_blocking(
tool_name="tools/list",
tool_args={},
)
else:
tools_result = await mcp_streamable_client_run(
gateway_url + "/api/v1/gateway/mcp/streamable",
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(
_get_streamable_server_base_url(transport), project_name, True
),
)
assert "blocked_get_last_message_from_user" in str(tools_result), (
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result)
@@ -1,69 +1,14 @@
"""This is a simple example of how to use the MCP client with Streamable HTTP transport."""
# pylint: disable=E1101
# pylint: disable=W0201
# pylint: disable=C2801
import asyncio
from datetime import timedelta
from typing import Any, Optional
from contextlib import AsyncExitStack
from typing import Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
class MCPClient:
"""MCP Client for interacting with a MCP Streamable HTTP server and processing queries"""
def __init__(self):
# Initialize session and client objects
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self._streams_context = None # Initialize these to None
self._session_context = None # so they always exist
async def connect_to_streamable_server(
self, server_url: str, headers: Optional[dict] = None
):
"""
Connect to an MCP server running with Streamable HTTP transport
Args:
server_url: URL of the MCP server
headers: Optional headers to include in the request
"""
# Store the context managers so they stay alive
self._streams_context = streamablehttp_client(
url=server_url,
headers=headers or {},
timeout=timedelta(seconds=5),
sse_read_timeout=timedelta(seconds=10),
)
read_stream, write_stream, session_id = await self._streams_context.__aenter__()
self.session_id = session_id
self._session_context = ClientSession(read_stream, write_stream)
self.session: ClientSession = await self._session_context.__aenter__()
await self.session.initialize()
async def cleanup(self):
"""Properly clean up the session and streams"""
if self._session_context:
await self._session_context.__aexit__(None, None, None)
if self._streams_context:
await self._streams_context.__aexit__(None, None, None)
async def process_query(self, tool_name: str, tool_args: dict) -> str:
"""Process a query using MCP server"""
result = await self.session.call_tool(
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
)
return result
async def run(
gateway_url: str,
push_to_explorer: bool,
@@ -81,21 +26,27 @@ async def run(
tool_args: Arguments for the tool call
"""
client = MCPClient()
try:
await client.connect_to_streamable_server(
server_url=gateway_url, headers=headers or {}
streams_context = streamablehttp_client(
url=gateway_url,
headers=headers or {},
timeout=timedelta(seconds=5),
sse_read_timeout=timedelta(seconds=10),
)
# list tools
listed_tools = await client.session.list_tools()
# call tool
if tool_name == "tools/list":
return listed_tools
else:
return await client.process_query(tool_name, tool_args)
async with streams_context as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
# list tools
listed_tools = await session.list_tools()
# call tool
if tool_name == "tools/list":
return listed_tools
else:
return await session.call_tool(
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
)
finally:
# Sleep for a while to allow the server to process the background tasks
# like pushing traces to the explorer
if push_to_explorer:
await asyncio.sleep(2)
await client.cleanup()