diff --git a/gateway/common/mcp_utils.py b/gateway/common/mcp_utils.py index da7322f..c6461e9 100644 --- a/gateway/common/mcp_utils.py +++ b/gateway/common/mcp_utils.py @@ -89,6 +89,7 @@ async def hook_tool_call( guardrails_result = await session.get_guardrails_check_result( message, action=GuardrailAction.BLOCK ) + print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True) # If the request is blocked, return a message indicating the block reason. # If there are new errors, run append_and_push_trace in background. # If there are no new errors, just return the original request. diff --git a/gateway/routes/mcp_streamable.py b/gateway/routes/mcp_streamable.py index 4c41233..6016ff6 100644 --- a/gateway/routes/mcp_streamable.py +++ b/gateway/routes/mcp_streamable.py @@ -81,8 +81,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse: # Intercept the request and check for guardrails. if not is_initialization_request: - if result := await _intercept_request(session_id, request_body) and result: # noqa: F821 pylint: disable=used-before-assignment - return result + request_interception_result = await _intercept_request(session_id, request_body) + if request_interception_result: + return request_interception_result async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: try: diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index be81849..4b695c0 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -164,7 +164,17 @@ async def test_mcp_with_gateway( @pytest.mark.asyncio @pytest.mark.timeout(30) -@pytest.mark.parametrize("transport", ["stdio", "sse"]) +@pytest.mark.parametrize( + "transport", + [ + "stdio", + "sse", + "streamable-json-stateless", + "streamable-json-stateful", + "streamable-sse-stateless", + "streamable-sse-stateful", + ], +) async def test_mcp_with_gateway_and_logging_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport ): @@ -201,7 +211,7 @@ async def test_mcp_with_gateway_and_logging_guardrails( tool_args={"username": "Alice"}, headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True), ) - else: + elif transport == "stdio": result = await mcp_stdio_client_run( invariant_gateway_package_whl_file, project_name, @@ -210,6 +220,16 @@ async def test_mcp_with_gateway_and_logging_guardrails( tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, ) + else: + result = await mcp_streamable_client_run( + gateway_url + "/api/v1/gateway/mcp/streamable", + push_to_explorer=True, + tool_name="get_last_message_from_user", + tool_args={"username": "Alice"}, + headers=_get_headers( + _get_streamable_server_base_url(transport), project_name, True + ), + ) assert result.isError is False assert ( @@ -232,6 +252,7 @@ async def test_mcp_with_gateway_and_logging_guardrails( timeout=5, ) trace = trace_response.json() + metadata = trace["extra_metadata"] assert ( metadata["source"] == "mcp" @@ -240,6 +261,19 @@ async def test_mcp_with_gateway_and_logging_guardrails( ) assert "session_id" in metadata assert "system_user" in metadata + if transport == "streamable-json-stateless": + assert metadata["server_response_type"] == "json" + assert metadata["is_stateless_http_server"] is True + elif transport == "streamable-json-stateful": + assert metadata["server_response_type"] == "json" + assert metadata["is_stateless_http_server"] is False + elif transport == "streamable-sse-stateless": + assert metadata["server_response_type"] == "sse" + assert metadata["is_stateless_http_server"] is True + elif transport == "streamable-sse-stateful": + assert metadata["server_response_type"] == "sse" + assert metadata["is_stateless_http_server"] is False + assert trace["messages"][2]["role"] == "assistant" assert trace["messages"][2]["tool_calls"][0]["function"] == { "name": "get_last_message_from_user", @@ -279,7 +313,17 @@ async def test_mcp_with_gateway_and_logging_guardrails( @pytest.mark.asyncio @pytest.mark.timeout(30) -@pytest.mark.parametrize("transport", ["stdio", "sse"]) +@pytest.mark.parametrize( + "transport", + [ + "stdio", + "sse", + "streamable-json-stateless", + "streamable-json-stateful", + "streamable-sse-stateless", + "streamable-sse-stateful", + ], +) async def test_mcp_with_gateway_and_blocking_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport ): @@ -312,7 +356,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails( _get_mcp_sse_server_base_url(), project_name, True ), ) - else: + elif transport == "stdio": _ = await mcp_stdio_client_run( invariant_gateway_package_whl_file, project_name, @@ -321,8 +365,9 @@ async def test_mcp_with_gateway_and_blocking_guardrails( tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, ) - # If we get here, the tool call was not blocked - pytest.fail("Expected McpError to be raised") + if not transport.startswith("streamable-"): + # If we get here, the tool call was not blocked + pytest.fail("Expected McpError to be raised") # The tool call should be blocked by the guardrail # and an error should be raised. except McpError as e: @@ -333,6 +378,25 @@ async def test_mcp_with_gateway_and_blocking_guardrails( assert "get_last_message_from_user is called" in e.error.message assert e.error.code == -32600 + if transport.startswith("streamable-"): + with pytest.raises(ExceptionGroup) as exc_group: + _ = await mcp_streamable_client_run( + gateway_url + "/api/v1/gateway/mcp/streamable", + push_to_explorer=True, + tool_name="get_last_message_from_user", + tool_args={"username": "Alice"}, + headers=_get_headers( + _get_streamable_server_base_url(transport), project_name, True + ), + ) + # Extract the actual HTTPStatusError + http_errors = [ + e + for e in exc_group.value.exceptions + if isinstance(e, httpx.HTTPStatusError) + ] + assert http_errors[0].response.status_code == 400 + # Fetch the trace ids for the dataset traces_response = requests.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces", @@ -375,7 +439,17 @@ async def test_mcp_with_gateway_and_blocking_guardrails( @pytest.mark.asyncio @pytest.mark.timeout(30) -@pytest.mark.parametrize("transport", ["stdio", "sse"]) +@pytest.mark.parametrize( + "transport", + [ + "stdio", + "sse", + "streamable-json-stateless", + "streamable-json-stateful", + "streamable-sse-stateless", + "streamable-sse-stateful", + ], +) async def test_mcp_with_gateway_hybrid_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport ): @@ -416,7 +490,7 @@ async def test_mcp_with_gateway_hybrid_guardrails( _get_mcp_sse_server_base_url(), project_name, True ), ) - else: + elif transport == "stdio": _ = await mcp_stdio_client_run( invariant_gateway_package_whl_file, project_name, @@ -425,8 +499,9 @@ async def test_mcp_with_gateway_hybrid_guardrails( tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, ) - # If we get here, the tool call was not blocked - pytest.fail("Expected McpError to be raised") + if not transport.startswith("streamable-"): + # If we get here, the tool call was not blocked + pytest.fail("Expected McpError to be raised") # The tool call output should be blocked by the guardrail # and an error should be raised. except McpError as e: @@ -437,6 +512,34 @@ async def test_mcp_with_gateway_hybrid_guardrails( assert "food in ToolOutput" in e.error.message assert e.error.code == -32600 + if transport.startswith("streamable-"): + with pytest.raises(ExceptionGroup) as exc_group: + _ = await mcp_streamable_client_run( + gateway_url + "/api/v1/gateway/mcp/streamable", + push_to_explorer=True, + tool_name="get_last_message_from_user", + tool_args={"username": "Alice"}, + headers=_get_headers( + _get_streamable_server_base_url(transport), project_name, True + ), + ) + if transport.startswith("streamable-json"): + # Extract the actual HTTPStatusError + http_errors = [ + e + for e in exc_group.value.exceptions + if isinstance(e, httpx.HTTPStatusError) + ] + assert http_errors[0].response.status_code == 400 + else: + mcp_error = [e for e in exc_group.value.exceptions][0].exceptions[0] + assert ( + "[Invariant Guardrails] The MCP tool call was blocked for security reasons" + in mcp_error.error.message + ) + assert "food in ToolOutput" in mcp_error.error.message + assert -32600 == mcp_error.error.code + # Fetch the trace ids for the dataset traces_response = requests.get( f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces", @@ -499,7 +602,17 @@ async def test_mcp_with_gateway_hybrid_guardrails( @pytest.mark.asyncio @pytest.mark.timeout(30) -@pytest.mark.parametrize("transport", ["stdio", "sse"]) +@pytest.mark.parametrize( + "transport", + [ + "stdio", + "sse", + "streamable-json-stateless", + "streamable-json-stateful", + "streamable-sse-stateless", + "streamable-sse-stateful", + ], +) async def test_mcp_tool_list_blocking( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport ): @@ -524,6 +637,26 @@ async def test_mcp_tool_list_blocking( invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), ) + if transport.startswith("streamable-json"): + with pytest.raises(ExceptionGroup) as exc_group: + _ = await mcp_streamable_client_run( + gateway_url + "/api/v1/gateway/mcp/streamable", + push_to_explorer=True, + tool_name="tools/list", + tool_args={}, + headers=_get_headers( + _get_streamable_server_base_url(transport), project_name, True + ), + ) + # Extract the actual HTTPStatusError + http_errors = [ + e + for e in exc_group.value.exceptions + if isinstance(e, httpx.HTTPStatusError) + ] + assert http_errors[0].response.status_code == 400 + return + # Run the MCP client and make the tools/list call. if transport == "sse": tools_result = await mcp_sse_client_run( @@ -533,7 +666,7 @@ async def test_mcp_tool_list_blocking( tool_args={}, headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True), ) - else: + elif transport == "stdio": tools_result = await mcp_stdio_client_run( invariant_gateway_package_whl_file, project_name, @@ -542,7 +675,16 @@ async def test_mcp_tool_list_blocking( tool_name="tools/list", tool_args={}, ) - + else: + tools_result = await mcp_streamable_client_run( + gateway_url + "/api/v1/gateway/mcp/streamable", + push_to_explorer=True, + tool_name="tools/list", + tool_args={}, + headers=_get_headers( + _get_streamable_server_base_url(transport), project_name, True + ), + ) assert "blocked_get_last_message_from_user" in str(tools_result), ( "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + str(tools_result) diff --git a/tests/integration/resources/mcp/streamable/client/main.py b/tests/integration/resources/mcp/streamable/client/main.py index 0525e00..a498e5b 100644 --- a/tests/integration/resources/mcp/streamable/client/main.py +++ b/tests/integration/resources/mcp/streamable/client/main.py @@ -1,69 +1,14 @@ """This is a simple example of how to use the MCP client with Streamable HTTP transport.""" -# pylint: disable=E1101 -# pylint: disable=W0201 -# pylint: disable=C2801 - import asyncio from datetime import timedelta -from typing import Any, Optional -from contextlib import AsyncExitStack +from typing import Any from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -class MCPClient: - """MCP Client for interacting with a MCP Streamable HTTP server and processing queries""" - - def __init__(self): - # Initialize session and client objects - self.session: Optional[ClientSession] = None - self.exit_stack = AsyncExitStack() - self._streams_context = None # Initialize these to None - self._session_context = None # so they always exist - - async def connect_to_streamable_server( - self, server_url: str, headers: Optional[dict] = None - ): - """ - Connect to an MCP server running with Streamable HTTP transport - - Args: - server_url: URL of the MCP server - headers: Optional headers to include in the request - """ - # Store the context managers so they stay alive - self._streams_context = streamablehttp_client( - url=server_url, - headers=headers or {}, - timeout=timedelta(seconds=5), - sse_read_timeout=timedelta(seconds=10), - ) - read_stream, write_stream, session_id = await self._streams_context.__aenter__() - self.session_id = session_id - - self._session_context = ClientSession(read_stream, write_stream) - self.session: ClientSession = await self._session_context.__aenter__() - - await self.session.initialize() - - async def cleanup(self): - """Properly clean up the session and streams""" - if self._session_context: - await self._session_context.__aexit__(None, None, None) - if self._streams_context: - await self._streams_context.__aexit__(None, None, None) - - async def process_query(self, tool_name: str, tool_args: dict) -> str: - """Process a query using MCP server""" - result = await self.session.call_tool( - tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10) - ) - return result - - async def run( gateway_url: str, push_to_explorer: bool, @@ -81,21 +26,27 @@ async def run( tool_args: Arguments for the tool call """ - client = MCPClient() try: - await client.connect_to_streamable_server( - server_url=gateway_url, headers=headers or {} + streams_context = streamablehttp_client( + url=gateway_url, + headers=headers or {}, + timeout=timedelta(seconds=5), + sse_read_timeout=timedelta(seconds=10), ) - # list tools - listed_tools = await client.session.list_tools() - # call tool - if tool_name == "tools/list": - return listed_tools - else: - return await client.process_query(tool_name, tool_args) + async with streams_context as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + # list tools + listed_tools = await session.list_tools() + # call tool + if tool_name == "tools/list": + return listed_tools + else: + return await session.call_tool( + tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10) + ) finally: # Sleep for a while to allow the server to process the background tasks # like pushing traces to the explorer if push_to_explorer: await asyncio.sleep(2) - await client.cleanup()