diff --git a/gateway/mcp/log.py b/gateway/mcp/log.py index 667119d..7287513 100644 --- a/gateway/mcp/log.py +++ b/gateway/mcp/log.py @@ -17,3 +17,14 @@ sys.stderr = MCP_LOG_FILE def mcp_log(*args, **kwargs) -> None: """Custom print function to redirect output to log_out.""" builtins_print(*args, **kwargs, file=MCP_LOG_FILE, flush=True) + +def format_errors_in_response(errors: list[dict]) -> str: + """Format a list of errors in a response string.""" + + def format_error(error: dict) -> str: + msg = " ".join(error.get("args", [])) + msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()]) + msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})" + return msg + + return ", ".join([format_error(error) for error in errors]) \ No newline at end of file diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py index ffb4ca7..73dadce 100644 --- a/gateway/mcp/mcp.py +++ b/gateway/mcp/mcp.py @@ -25,7 +25,7 @@ from gateway.common.guardrails import GuardrailAction from gateway.common.request_context import RequestContext from gateway.integrations.explorer import create_annotations_from_guardrails_errors from gateway.integrations.guardrails import check_guardrails -from gateway.mcp.log import mcp_log, MCP_LOG_FILE +from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response from gateway.mcp.mcp_context import McpContext from gateway.mcp.task_utils import run_task_in_background, run_task_sync import getpass @@ -403,18 +403,6 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict: return result -def format_errors_in_response(errors: list[dict]) -> str: - """Format a list of errors in a response string.""" - - def format_error(error: dict) -> str: - msg = " ".join(error.get("args", [])) - msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()]) - msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})" - return msg - - return ", ".join([format_error(error) for error in errors]) - - async def stream_and_forward_stdout( mcp_process: subprocess.Popen, ctx: McpContext ) -> None: diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index d70b220..4928a12 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse from gateway.common.constants import ( CLIENT_TIMEOUT, INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, + INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE, MCP_METHOD, MCP_TOOL_CALL, MCP_LIST_TOOLS, @@ -27,6 +28,7 @@ from gateway.common.mcp_sessions_manager import ( McpSessionsManager, SseHeaderAttributes, ) +from gateway.mcp.log import format_errors_in_response from gateway.integrations.explorer import create_annotations_from_guardrails_errors MCP_SERVER_POST_HEADERS = { @@ -45,6 +47,7 @@ gateway = APIRouter() session_store = McpSessionsManager() +@gateway.post("/mcp/messages/") @gateway.post("/mcp/sse/messages/") async def mcp_post_gateway( request: Request, @@ -104,8 +107,8 @@ async def mcp_post_gateway( "id": request_json.get("id"), "method": MCP_LIST_TOOLS, "params": { - "name": request_json.get(MCP_METHOD), - "arguments": {} + "name": MCP_LIST_TOOLS, + "arguments": "{}" }, } ) @@ -144,6 +147,7 @@ async def mcp_post_gateway( raise HTTPException(status_code=500, detail="Unexpected error") from e + @gateway.get("/mcp/sse") async def mcp_get_sse_gateway( request: Request, @@ -151,6 +155,7 @@ async def mcp_get_sse_gateway( """Proxy calls to the MCP Server tools""" mcp_server_base_url = request.headers.get("mcp-server-base-url") if not mcp_server_base_url: + print("missing base url", request.headers, flush=True) raise HTTPException( status_code=400, detail="Missing 'mcp-server-base-url' header", @@ -355,7 +360,7 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo return request_json, False -async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict: +async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict: """ Hook to process the response JSON after receiving it from the MCP server. @@ -366,6 +371,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict dict: The response JSON is returned if no guardrail is violated else an error dict is returned. """ + blocked = False message = { "role": "tool", "tool_call_id": f"call_{response_json.get('id')}", @@ -383,21 +389,50 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict and guardrails_result.get("errors", []) and _check_if_new_errors(session_id, guardrails_result) ): + blocked = True # If the request is blocked, return a message indicating the block reason. - result = { - "jsonrpc": "2.0", - "id": response_json.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - } + if not is_tools_list: + result = { + "jsonrpc": "2.0", + "id": response_json.get("id"), + "error": { + "code": -32600, + "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE + % guardrails_result["errors"], + }, + } + else: + # special error response for tools/list tool call + result = { + "jsonrpc": "2.0", + "id": response_json.get("id"), + "result": { + "tools": [ + { + "name": "blocked_" + tool["name"], + "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE + % format_errors_in_response(guardrails_result["errors"]), + # no parameters + "inputSchema": { + "properties": {}, + "required": [], + "title": "invariant_mcp_server_blockedArguments", + "type": "object", + }, + "annotations": { + "title": "This tool was blocked by security guardrails.", + }, + } + for tool in response_json["result"]["tools"] + ] + } + } + # Push trace to the explorer - don't block on its response asyncio.create_task( session_store.add_message_to_session(session_id, message, guardrails_result) ) - return result + return result, blocked def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: @@ -480,7 +515,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: method = session.id_to_method_mapping.get(response_json.get("id")) if method == MCP_TOOL_CALL: - hook_tool_call_response = await _hook_tool_call_response( + hook_tool_call_response, blocked = await _hook_tool_call_response( session_id=session_id, response_json=response_json, ) @@ -488,24 +523,35 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: # hook_tool_call_response is same as response_json if no guardrail is violated. # If guardrail is violated, it contains the error message. # pylint: disable=line-too-long - event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode( - "utf-8" - ) + if blocked: + event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode( + "utf-8" + ) elif method == MCP_LIST_TOOLS: # store tools in metadata session_store.get_session(session_id).metadata["tools"] = response_json.get( MCP_RESULT ).get("tools") # store tools/list tool call in trace - hook_tool_call_response = await _hook_tool_call_response( + hook_tool_call_response, blocked = await _hook_tool_call_response( session_id=session_id, response_json={ "id": response_json.get("id"), "result": { - "content": json.dumps(response_json.get(MCP_RESULT).get("tools")) + "content": json.dumps(response_json.get(MCP_RESULT).get("tools")), + "tools": response_json.get(MCP_RESULT).get("tools"), } - } + }, + is_tools_list=True, ) + # Update the event bytes with hook_tool_call_response. + # hook_tool_call_response is same as response_json if no guardrail is violated. + # If guardrail is violated, it contains the error message. + # pylint: disable=line-too-long + if blocked: + event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode( + "utf-8" + ) except json.JSONDecodeError as e: print( @@ -513,9 +559,6 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: flush=True, ) except Exception as e: # pylint: disable=broad-except - if os.environ.get("DEBUG") == "true": - import traceback - traceback.print_exc() print( f"[MCP SSE] Error processing message: {e}", flush=True,