mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 15:29:43 +02:00
Add blocking and logging related tests for MCP streamable HTTP route.
This commit is contained in:
@@ -89,6 +89,7 @@ async def hook_tool_call(
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True)
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
|
||||
@@ -81,8 +81,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
|
||||
|
||||
# Intercept the request and check for guardrails.
|
||||
if not is_initialization_request:
|
||||
if result := await _intercept_request(session_id, request_body) and result: # noqa: F821 pylint: disable=used-before-assignment
|
||||
return result
|
||||
request_interception_result = await _intercept_request(session_id, request_body)
|
||||
if request_interception_result:
|
||||
return request_interception_result
|
||||
|
||||
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
|
||||
try:
|
||||
|
||||
@@ -164,7 +164,17 @@ async def test_mcp_with_gateway(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
@pytest.mark.parametrize(
|
||||
"transport",
|
||||
[
|
||||
"stdio",
|
||||
"sse",
|
||||
"streamable-json-stateless",
|
||||
"streamable-json-stateful",
|
||||
"streamable-sse-stateless",
|
||||
"streamable-sse-stateful",
|
||||
],
|
||||
)
|
||||
async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
):
|
||||
@@ -201,7 +211,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
else:
|
||||
elif transport == "stdio":
|
||||
result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -210,6 +220,16 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
else:
|
||||
result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
|
||||
assert result.isError is False
|
||||
assert (
|
||||
@@ -232,6 +252,7 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
metadata = trace["extra_metadata"]
|
||||
assert (
|
||||
metadata["source"] == "mcp"
|
||||
@@ -240,6 +261,19 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
)
|
||||
assert "session_id" in metadata
|
||||
assert "system_user" in metadata
|
||||
if transport == "streamable-json-stateless":
|
||||
assert metadata["server_response_type"] == "json"
|
||||
assert metadata["is_stateless_http_server"] is True
|
||||
elif transport == "streamable-json-stateful":
|
||||
assert metadata["server_response_type"] == "json"
|
||||
assert metadata["is_stateless_http_server"] is False
|
||||
elif transport == "streamable-sse-stateless":
|
||||
assert metadata["server_response_type"] == "sse"
|
||||
assert metadata["is_stateless_http_server"] is True
|
||||
elif transport == "streamable-sse-stateful":
|
||||
assert metadata["server_response_type"] == "sse"
|
||||
assert metadata["is_stateless_http_server"] is False
|
||||
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
@@ -279,7 +313,17 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
@pytest.mark.parametrize(
|
||||
"transport",
|
||||
[
|
||||
"stdio",
|
||||
"sse",
|
||||
"streamable-json-stateless",
|
||||
"streamable-json-stateful",
|
||||
"streamable-sse-stateless",
|
||||
"streamable-sse-stateful",
|
||||
],
|
||||
)
|
||||
async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
):
|
||||
@@ -312,7 +356,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
_get_mcp_sse_server_base_url(), project_name, True
|
||||
),
|
||||
)
|
||||
else:
|
||||
elif transport == "stdio":
|
||||
_ = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -321,8 +365,9 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
# If we get here, the tool call was not blocked
|
||||
pytest.fail("Expected McpError to be raised")
|
||||
if not transport.startswith("streamable-"):
|
||||
# If we get here, the tool call was not blocked
|
||||
pytest.fail("Expected McpError to be raised")
|
||||
# The tool call should be blocked by the guardrail
|
||||
# and an error should be raised.
|
||||
except McpError as e:
|
||||
@@ -333,6 +378,25 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
assert "get_last_message_from_user is called" in e.error.message
|
||||
assert e.error.code == -32600
|
||||
|
||||
if transport.startswith("streamable-"):
|
||||
with pytest.raises(ExceptionGroup) as exc_group:
|
||||
_ = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
# Extract the actual HTTPStatusError
|
||||
http_errors = [
|
||||
e
|
||||
for e in exc_group.value.exceptions
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
]
|
||||
assert http_errors[0].response.status_code == 400
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
|
||||
@@ -375,7 +439,17 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
@pytest.mark.parametrize(
|
||||
"transport",
|
||||
[
|
||||
"stdio",
|
||||
"sse",
|
||||
"streamable-json-stateless",
|
||||
"streamable-json-stateful",
|
||||
"streamable-sse-stateless",
|
||||
"streamable-sse-stateful",
|
||||
],
|
||||
)
|
||||
async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
):
|
||||
@@ -416,7 +490,7 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
_get_mcp_sse_server_base_url(), project_name, True
|
||||
),
|
||||
)
|
||||
else:
|
||||
elif transport == "stdio":
|
||||
_ = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -425,8 +499,9 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
# If we get here, the tool call was not blocked
|
||||
pytest.fail("Expected McpError to be raised")
|
||||
if not transport.startswith("streamable-"):
|
||||
# If we get here, the tool call was not blocked
|
||||
pytest.fail("Expected McpError to be raised")
|
||||
# The tool call output should be blocked by the guardrail
|
||||
# and an error should be raised.
|
||||
except McpError as e:
|
||||
@@ -437,6 +512,34 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
assert "food in ToolOutput" in e.error.message
|
||||
assert e.error.code == -32600
|
||||
|
||||
if transport.startswith("streamable-"):
|
||||
with pytest.raises(ExceptionGroup) as exc_group:
|
||||
_ = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
if transport.startswith("streamable-json"):
|
||||
# Extract the actual HTTPStatusError
|
||||
http_errors = [
|
||||
e
|
||||
for e in exc_group.value.exceptions
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
]
|
||||
assert http_errors[0].response.status_code == 400
|
||||
else:
|
||||
mcp_error = [e for e in exc_group.value.exceptions][0].exceptions[0]
|
||||
assert (
|
||||
"[Invariant Guardrails] The MCP tool call was blocked for security reasons"
|
||||
in mcp_error.error.message
|
||||
)
|
||||
assert "food in ToolOutput" in mcp_error.error.message
|
||||
assert -32600 == mcp_error.error.code
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
|
||||
@@ -499,7 +602,17 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
@pytest.mark.parametrize(
|
||||
"transport",
|
||||
[
|
||||
"stdio",
|
||||
"sse",
|
||||
"streamable-json-stateless",
|
||||
"streamable-json-stateful",
|
||||
"streamable-sse-stateless",
|
||||
"streamable-sse-stateful",
|
||||
],
|
||||
)
|
||||
async def test_mcp_tool_list_blocking(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
):
|
||||
@@ -524,6 +637,26 @@ async def test_mcp_tool_list_blocking(
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
)
|
||||
|
||||
if transport.startswith("streamable-json"):
|
||||
with pytest.raises(ExceptionGroup) as exc_group:
|
||||
_ = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
# Extract the actual HTTPStatusError
|
||||
http_errors = [
|
||||
e
|
||||
for e in exc_group.value.exceptions
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
]
|
||||
assert http_errors[0].response.status_code == 400
|
||||
return
|
||||
|
||||
# Run the MCP client and make the tools/list call.
|
||||
if transport == "sse":
|
||||
tools_result = await mcp_sse_client_run(
|
||||
@@ -533,7 +666,7 @@ async def test_mcp_tool_list_blocking(
|
||||
tool_args={},
|
||||
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
|
||||
)
|
||||
else:
|
||||
elif transport == "stdio":
|
||||
tools_result = await mcp_stdio_client_run(
|
||||
invariant_gateway_package_whl_file,
|
||||
project_name,
|
||||
@@ -542,7 +675,16 @@ async def test_mcp_tool_list_blocking(
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
)
|
||||
|
||||
else:
|
||||
tools_result = await mcp_streamable_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/streamable",
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(
|
||||
_get_streamable_server_base_url(transport), project_name, True
|
||||
),
|
||||
)
|
||||
assert "blocked_get_last_message_from_user" in str(tools_result), (
|
||||
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
|
||||
+ str(tools_result)
|
||||
|
||||
@@ -1,69 +1,14 @@
|
||||
"""This is a simple example of how to use the MCP client with Streamable HTTP transport."""
|
||||
|
||||
# pylint: disable=E1101
|
||||
# pylint: disable=W0201
|
||||
# pylint: disable=C2801
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Any, Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP Client for interacting with a MCP Streamable HTTP server and processing queries"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self._streams_context = None # Initialize these to None
|
||||
self._session_context = None # so they always exist
|
||||
|
||||
async def connect_to_streamable_server(
|
||||
self, server_url: str, headers: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Connect to an MCP server running with Streamable HTTP transport
|
||||
|
||||
Args:
|
||||
server_url: URL of the MCP server
|
||||
headers: Optional headers to include in the request
|
||||
"""
|
||||
# Store the context managers so they stay alive
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=server_url,
|
||||
headers=headers or {},
|
||||
timeout=timedelta(seconds=5),
|
||||
sse_read_timeout=timedelta(seconds=10),
|
||||
)
|
||||
read_stream, write_stream, session_id = await self._streams_context.__aenter__()
|
||||
self.session_id = session_id
|
||||
|
||||
self._session_context = ClientSession(read_stream, write_stream)
|
||||
self.session: ClientSession = await self._session_context.__aenter__()
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Properly clean up the session and streams"""
|
||||
if self._session_context:
|
||||
await self._session_context.__aexit__(None, None, None)
|
||||
if self._streams_context:
|
||||
await self._streams_context.__aexit__(None, None, None)
|
||||
|
||||
async def process_query(self, tool_name: str, tool_args: dict) -> str:
|
||||
"""Process a query using MCP server"""
|
||||
result = await self.session.call_tool(
|
||||
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def run(
|
||||
gateway_url: str,
|
||||
push_to_explorer: bool,
|
||||
@@ -81,21 +26,27 @@ async def run(
|
||||
tool_args: Arguments for the tool call
|
||||
|
||||
"""
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_to_streamable_server(
|
||||
server_url=gateway_url, headers=headers or {}
|
||||
streams_context = streamablehttp_client(
|
||||
url=gateway_url,
|
||||
headers=headers or {},
|
||||
timeout=timedelta(seconds=5),
|
||||
sse_read_timeout=timedelta(seconds=10),
|
||||
)
|
||||
# list tools
|
||||
listed_tools = await client.session.list_tools()
|
||||
# call tool
|
||||
if tool_name == "tools/list":
|
||||
return listed_tools
|
||||
else:
|
||||
return await client.process_query(tool_name, tool_args)
|
||||
async with streams_context as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
# list tools
|
||||
listed_tools = await session.list_tools()
|
||||
# call tool
|
||||
if tool_name == "tools/list":
|
||||
return listed_tools
|
||||
else:
|
||||
return await session.call_tool(
|
||||
tool_name, tool_args, read_timeout_seconds=timedelta(seconds=10)
|
||||
)
|
||||
finally:
|
||||
# Sleep for a while to allow the server to process the background tasks
|
||||
# like pushing traces to the explorer
|
||||
if push_to_explorer:
|
||||
await asyncio.sleep(2)
|
||||
await client.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user