support for blocking tools/list in SSE

This commit is contained in:
Luca Beurer-Kellner
2025-05-16 21:33:16 +02:00
parent c16a12fdec
commit 6363c6eb97
3 changed files with 78 additions and 36 deletions
+11
View File
@@ -17,3 +17,14 @@ sys.stderr = MCP_LOG_FILE
def mcp_log(*args, **kwargs) -> None:
"""Custom print function to redirect output to log_out."""
builtins_print(*args, **kwargs, file=MCP_LOG_FILE, flush=True)
def format_errors_in_response(errors: list[dict]) -> str:
"""Format a list of errors in a response string."""
def format_error(error: dict) -> str:
msg = " ".join(error.get("args", []))
msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()])
msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})"
return msg
return ", ".join([format_error(error) for error in errors])
+1 -13
View File
@@ -25,7 +25,7 @@ from gateway.common.guardrails import GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
import getpass
@@ -403,18 +403,6 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
return result
def format_errors_in_response(errors: list[dict]) -> str:
"""Format a list of errors in a response string."""
def format_error(error: dict) -> str:
msg = " ".join(error.get("args", []))
msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()])
msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})"
return msg
return ", ".join([format_error(error) for error in errors])
async def stream_and_forward_stdout(
mcp_process: subprocess.Popen, ctx: McpContext
) -> None:
+66 -23
View File
@@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse
from gateway.common.constants import (
CLIENT_TIMEOUT,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
@@ -27,6 +28,7 @@ from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
)
from gateway.mcp.log import format_errors_in_response
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
MCP_SERVER_POST_HEADERS = {
@@ -45,6 +47,7 @@ gateway = APIRouter()
session_store = McpSessionsManager()
@gateway.post("/mcp/messages/")
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
request: Request,
@@ -104,8 +107,8 @@ async def mcp_post_gateway(
"id": request_json.get("id"),
"method": MCP_LIST_TOOLS,
"params": {
"name": request_json.get(MCP_METHOD),
"arguments": {}
"name": MCP_LIST_TOOLS,
"arguments": "{}"
},
}
)
@@ -144,6 +147,7 @@ async def mcp_post_gateway(
raise HTTPException(status_code=500, detail="Unexpected error") from e
@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
request: Request,
@@ -151,6 +155,7 @@ async def mcp_get_sse_gateway(
"""Proxy calls to the MCP Server tools"""
mcp_server_base_url = request.headers.get("mcp-server-base-url")
if not mcp_server_base_url:
print("missing base url", request.headers, flush=True)
raise HTTPException(
status_code=400,
detail="Missing 'mcp-server-base-url' header",
@@ -355,7 +360,7 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo
return request_json, False
async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict:
async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict:
"""
Hook to process the response JSON after receiving it from the MCP server.
@@ -366,6 +371,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
dict: The response JSON is returned if no guardrail is violated
else an error dict is returned.
"""
blocked = False
message = {
"role": "tool",
"tool_call_id": f"call_{response_json.get('id')}",
@@ -383,21 +389,50 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
and guardrails_result.get("errors", [])
and _check_if_new_errors(session_id, guardrails_result)
):
blocked = True
# If the request is blocked, return a message indicating the block reason.
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}
if not is_tools_list:
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}
else:
# special error response for tools/list tool call
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"result": {
"tools": [
{
"name": "blocked_" + tool["name"],
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
% format_errors_in_response(guardrails_result["errors"]),
# no parameters
"inputSchema": {
"properties": {},
"required": [],
"title": "invariant_mcp_server_blockedArguments",
"type": "object",
},
"annotations": {
"title": "This tool was blocked by security guardrails.",
},
}
for tool in response_json["result"]["tools"]
]
}
}
# Push trace to the explorer - don't block on its response
asyncio.create_task(
session_store.add_message_to_session(session_id, message, guardrails_result)
)
return result
return result, blocked
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
@@ -480,7 +515,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
method = session.id_to_method_mapping.get(response_json.get("id"))
if method == MCP_TOOL_CALL:
hook_tool_call_response = await _hook_tool_call_response(
hook_tool_call_response, blocked = await _hook_tool_call_response(
session_id=session_id,
response_json=response_json,
)
@@ -488,24 +523,35 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# hook_tool_call_response is same as response_json if no guardrail is violated.
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
)
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
)
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).metadata["tools"] = response_json.get(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
hook_tool_call_response = await _hook_tool_call_response(
hook_tool_call_response, blocked = await _hook_tool_call_response(
session_id=session_id,
response_json={
"id": response_json.get("id"),
"result": {
"content": json.dumps(response_json.get(MCP_RESULT).get("tools"))
"content": json.dumps(response_json.get(MCP_RESULT).get("tools")),
"tools": response_json.get(MCP_RESULT).get("tools"),
}
}
},
is_tools_list=True,
)
# Update the event bytes with hook_tool_call_response.
# hook_tool_call_response is same as response_json if no guardrail is violated.
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
)
except json.JSONDecodeError as e:
print(
@@ -513,9 +559,6 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
flush=True,
)
except Exception as e: # pylint: disable=broad-except
if os.environ.get("DEBUG") == "true":
import traceback
traceback.print_exc()
print(
f"[MCP SSE] Error processing message: {e}",
flush=True,