mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 23:34:02 +02:00
support for blocking tools/list in SSE
This commit is contained in:
@@ -17,3 +17,14 @@ sys.stderr = MCP_LOG_FILE
|
||||
def mcp_log(*args, **kwargs) -> None:
|
||||
"""Custom print function to redirect output to log_out."""
|
||||
builtins_print(*args, **kwargs, file=MCP_LOG_FILE, flush=True)
|
||||
|
||||
def format_errors_in_response(errors: list[dict]) -> str:
|
||||
"""Format a list of errors in a response string."""
|
||||
|
||||
def format_error(error: dict) -> str:
|
||||
msg = " ".join(error.get("args", []))
|
||||
msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()])
|
||||
msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})"
|
||||
return msg
|
||||
|
||||
return ", ".join([format_error(error) for error in errors])
|
||||
+1
-13
@@ -25,7 +25,7 @@ from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
import getpass
|
||||
@@ -403,18 +403,6 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
def format_errors_in_response(errors: list[dict]) -> str:
|
||||
"""Format a list of errors in a response string."""
|
||||
|
||||
def format_error(error: dict) -> str:
|
||||
msg = " ".join(error.get("args", []))
|
||||
msg += " ".join([f"{k}={v}" for k, v in error.get("kwargs", {}).items()])
|
||||
msg += f" ([{error.get('guardrail', {}).get('id', 'unknown-guardrail')}] {error.get('guardrail', {}).get('name', 'unknown guardrail')})"
|
||||
return msg
|
||||
|
||||
return ", ".join([format_error(error) for error in errors])
|
||||
|
||||
|
||||
async def stream_and_forward_stdout(
|
||||
mcp_process: subprocess.Popen, ctx: McpContext
|
||||
) -> None:
|
||||
|
||||
+66
-23
@@ -14,6 +14,7 @@ from fastapi.responses import StreamingResponse
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
@@ -27,6 +28,7 @@ from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
|
||||
MCP_SERVER_POST_HEADERS = {
|
||||
@@ -45,6 +47,7 @@ gateway = APIRouter()
|
||||
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
@gateway.post("/mcp/messages/")
|
||||
@gateway.post("/mcp/sse/messages/")
|
||||
async def mcp_post_gateway(
|
||||
request: Request,
|
||||
@@ -104,8 +107,8 @@ async def mcp_post_gateway(
|
||||
"id": request_json.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {
|
||||
"name": request_json.get(MCP_METHOD),
|
||||
"arguments": {}
|
||||
"name": MCP_LIST_TOOLS,
|
||||
"arguments": "{}"
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -144,6 +147,7 @@ async def mcp_post_gateway(
|
||||
raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
|
||||
|
||||
@gateway.get("/mcp/sse")
|
||||
async def mcp_get_sse_gateway(
|
||||
request: Request,
|
||||
@@ -151,6 +155,7 @@ async def mcp_get_sse_gateway(
|
||||
"""Proxy calls to the MCP Server tools"""
|
||||
mcp_server_base_url = request.headers.get("mcp-server-base-url")
|
||||
if not mcp_server_base_url:
|
||||
print("missing base url", request.headers, flush=True)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing 'mcp-server-base-url' header",
|
||||
@@ -355,7 +360,7 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo
|
||||
return request_json, False
|
||||
|
||||
|
||||
async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict:
|
||||
async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
@@ -366,6 +371,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
|
||||
dict: The response JSON is returned if no guardrail is violated
|
||||
else an error dict is returned.
|
||||
"""
|
||||
blocked = False
|
||||
message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": f"call_{response_json.get('id')}",
|
||||
@@ -383,21 +389,50 @@ async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
blocked = True
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
else:
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "blocked_" + tool["name"],
|
||||
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
|
||||
% format_errors_in_response(guardrails_result["errors"]),
|
||||
# no parameters
|
||||
"inputSchema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "invariant_mcp_server_blockedArguments",
|
||||
"type": "object",
|
||||
},
|
||||
"annotations": {
|
||||
"title": "This tool was blocked by security guardrails.",
|
||||
},
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result
|
||||
return result, blocked
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
@@ -480,7 +515,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
|
||||
method = session.id_to_method_mapping.get(response_json.get("id"))
|
||||
if method == MCP_TOOL_CALL:
|
||||
hook_tool_call_response = await _hook_tool_call_response(
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
response_json=response_json,
|
||||
)
|
||||
@@ -488,24 +523,35 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# hook_tool_call_response is same as response_json if no guardrail is violated.
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
session_store.get_session(session_id).metadata["tools"] = response_json.get(
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
hook_tool_call_response = await _hook_tool_call_response(
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
response_json={
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"content": json.dumps(response_json.get(MCP_RESULT).get("tools"))
|
||||
"content": json.dumps(response_json.get(MCP_RESULT).get("tools")),
|
||||
"tools": response_json.get(MCP_RESULT).get("tools"),
|
||||
}
|
||||
}
|
||||
},
|
||||
is_tools_list=True,
|
||||
)
|
||||
# Update the event bytes with hook_tool_call_response.
|
||||
# hook_tool_call_response is same as response_json if no guardrail is violated.
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
@@ -513,9 +559,6 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
if os.environ.get("DEBUG") == "true":
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(
|
||||
f"[MCP SSE] Error processing message: {e}",
|
||||
flush=True,
|
||||
|
||||
Reference in New Issue
Block a user