Move hook_tool_call and hook_tool_call_response to mcp_utils.py so that they can be used by both the SSE and Streamable implementations.

This commit is contained in:
Hemang
2025-05-27 14:53:19 +02:00
committed by Hemang Sarkar
parent 6e61a76168
commit 34979ed18d
3 changed files with 211 additions and 323 deletions

View File

@@ -1,10 +1,24 @@
"""MCP utility functions."""
import asyncio
import re
from fastapi import Request, HTTPException
from typing import Tuple
from gateway.common.constants import MCP_SERVER_BASE_URL_HEADER
from fastapi import Request, HTTPException
from gateway.common.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
MCP_SERVER_BASE_URL_HEADER,
MCP_PARAMS,
MCP_RESULT,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
)
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.mcp.log import format_errors_in_response
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
@@ -49,3 +63,150 @@ def get_mcp_server_base_url(request: Request) -> str:
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/")
async def hook_tool_call(
    session_id: str, session_store: McpSessionsManager, request_body: dict
) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The request is wrapped in an assistant message carrying a single tool
    call, checked against the session's blocking guardrails and recorded in
    the session.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): Store used to look up the session
            and to record the message.
        request_body (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: ``(error_response, True)`` when the guardrails
        report new errors, otherwise ``(request_body, False)``.
    """
    # A request without a params object (or with params set to null) must
    # not crash the hook with an AttributeError on None.
    params = request_body.get(MCP_PARAMS) or {}
    tool_call = {
        "id": f"call_{request_body.get('id')}",
        "type": "function",
        "function": {
            "name": params.get("name"),
            "arguments": params.get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # The request is blocked only when the check produced errors that are
    # not already present in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and check_if_new_errors(session_id, session_store, guardrails_result)
    ):
        # Record the blocked message in the background so the error response
        # is not delayed.
        asyncio.create_task(
            session_store.add_message_to_session(
                session_id=session_id,
                message=message,
                guardrails_result=guardrails_result,
            )
        )
        return {
            "jsonrpc": "2.0",
            "id": request_body.get("id"),
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }, True

    # Not blocked: record the message before handing the original request back.
    await session_store.add_message_to_session(session_id, message, guardrails_result)
    return request_body, False
def check_if_new_errors(
    session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
    """Return True if any annotation built from the guardrails errors is not
    yet among the session's annotations."""
    session = session_store.get_session(session_id)
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    # any() stops at the first unseen annotation, like the explicit loop did.
    return any(candidate not in session.annotations for candidate in candidates)
async def hook_tool_call_response(
    session_id: str,
    session_store: McpSessionsManager,
    response_json: dict,
    is_tools_list=False,
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    The response is wrapped in a tool message, checked against the session's
    blocking guardrails and recorded in the session.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): Store used to look up the session
            and to record the message.
        response_json (dict): The response JSON to be processed.
        is_tools_list (bool): When True, a blocked response is replaced by a
            tools list of "blocked_" placeholders instead of an error object.

    Returns:
        Tuple[dict, bool]: The response to send to the client and whether it
        was blocked. The original response JSON is returned if no guardrail
        is violated, else a replacement dict is returned.
    """
    blocked = False

    # Responses without a "result" member (for example JSON-RPC error
    # responses) must not crash the hook with an AttributeError on None.
    payload = response_json.get(MCP_RESULT) or {}
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": payload.get("content"),
        "error": payload.get("error"),
    }
    result = response_json

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Block only when the check produced errors that are not already present
    # in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and check_if_new_errors(session_id, session_store, guardrails_result)
    ):
        blocked = True
        if not is_tools_list:
            # Regular tool call: answer with a JSON-RPC error object.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "error": {
                    "code": -32600,
                    "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                    % guardrails_result["errors"],
                },
            }
        else:
            # special error response for tools/list tool call: every tool is
            # replaced by a parameterless "blocked_" placeholder.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "result": {
                    "tools": [
                        {
                            "name": "blocked_" + tool["name"],
                            "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                            % format_errors_in_response(guardrails_result["errors"]),
                            "inputSchema": {
                                "properties": {},
                                "required": [],
                                "title": "invariant_mcp_server_blockedArguments",
                                "type": "object",
                            },
                            "annotations": {
                                "title": "This tool was blocked by security guardrails.",
                            },
                        }
                        for tool in response_json["result"]["tools"]
                    ]
                },
            }

    # Push trace to the explorer - don't block on its response
    # NOTE(review): the task reference is not retained; the asyncio docs
    # recommend keeping one so the task cannot be garbage collected early.
    asyncio.create_task(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    return result, blocked

View File

@@ -12,8 +12,6 @@ from fastapi.responses import StreamingResponse
from gateway.common.constants import (
CLIENT_TIMEOUT,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
@@ -23,14 +21,15 @@ from gateway.common.constants import (
MCP_CLIENT_INFO,
UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
)
from gateway.common.mcp_utils import get_mcp_server_base_url
from gateway.mcp.log import format_errors_in_response
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.common.mcp_utils import (
get_mcp_server_base_url,
hook_tool_call,
hook_tool_call_response,
)
MCP_SERVER_POST_HEADERS = {
"connection",
@@ -71,35 +70,38 @@ async def mcp_post_sse_gateway(
get_mcp_server_base_url(request) + "/messages/?" + session_id
)
request_body_bytes = await request.body()
request_json = json.loads(request_body_bytes)
request_body = json.loads(request_body_bytes)
session = session_store.get_session(session_id)
if request_json.get(MCP_METHOD) and request_json.get("id"):
session.id_to_method_mapping[request_json.get("id")] = request_json.get(
if request_body.get(MCP_METHOD) and request_body.get("id"):
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
MCP_METHOD
)
if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.metadata["mcp_client"] = (
request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
if request_json.get(MCP_METHOD) == MCP_TOOL_CALL:
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
# Intercept and potentially block the request
hook_tool_call_result, is_blocked = await _hook_tool_call(
session_id=session_id, request_json=request_json
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=session_store,
request_body=request_body,
)
if is_blocked:
# Add the error message to the session.
# The error message is sent back to the client using the SSE stream.
await session.add_pending_error_message(hook_tool_call_result)
return Response(content="Accepted", status_code=202)
elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
# Intercept and potentially block the request
hook_tool_call_result, is_blocked = await _hook_tool_call(
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
request_json={
"id": request_json.get("id"),
session_store=session_store,
request_body={
"id": request_body.get("id"),
"method": MCP_LIST_TOOLS,
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
},
@@ -119,7 +121,7 @@ async def mcp_post_sse_gateway(
for k, v in request.headers.items()
if k.lower() in MCP_SERVER_POST_HEADERS
},
json=request_json,
json=request_body,
params=query_params,
)
return Response(
@@ -293,135 +295,6 @@ async def mcp_get_sse_gateway(
)
async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The request is wrapped in an assistant message carrying a single tool
    call, checked against the session's blocking guardrails and recorded in
    the session.

    Args:
        session_id (str): The session ID associated with the request.
        request_json (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: ``(error_response, True)`` when the guardrails
        report new errors, otherwise ``(request_json, False)``.
    """
    # A request without a params object (or with params set to null) must
    # not crash the hook with an AttributeError on None.
    params = request_json.get(MCP_PARAMS) or {}
    tool_call = {
        "id": f"call_{request_json.get('id')}",
        "type": "function",
        "function": {
            "name": params.get("name"),
            "arguments": params.get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # The request is blocked only when the check produced errors that are
    # not already present in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        # Record the blocked message in the background so the error response
        # is not delayed.
        asyncio.create_task(
            session_store.add_message_to_session(
                session_id=session_id,
                message=message,
                guardrails_result=guardrails_result,
            )
        )
        return {
            "jsonrpc": "2.0",
            "id": request_json.get("id"),
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }, True

    # Not blocked: record the message before handing the original request back.
    await session_store.add_message_to_session(session_id, message, guardrails_result)
    return request_json, False
async def _hook_tool_call_response(
    session_id: str, response_json: dict, is_tools_list=False
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    The response is wrapped in a tool message, checked against the session's
    blocking guardrails and recorded in the session.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.
        is_tools_list (bool): When True, a blocked response is replaced by a
            tools list of "blocked_" placeholders instead of an error object.

    Returns:
        Tuple[dict, bool]: The response to send to the client and whether it
        was blocked. The original response JSON is returned if no guardrail
        is violated, else a replacement dict is returned.
    """
    blocked = False

    # Responses without a "result" member (for example JSON-RPC error
    # responses) must not crash the hook with an AttributeError on None.
    payload = response_json.get(MCP_RESULT) or {}
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": payload.get("content"),
        "error": payload.get("error"),
    }
    result = response_json

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Block only when the check produced errors that are not already present
    # in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        blocked = True
        if not is_tools_list:
            # Regular tool call: answer with a JSON-RPC error object.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "error": {
                    "code": -32600,
                    "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                    % guardrails_result["errors"],
                },
            }
        else:
            # special error response for tools/list tool call: every tool is
            # replaced by a placeholder that takes no parameters.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "result": {
                    "tools": [
                        {
                            "name": "blocked_" + tool["name"],
                            "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                            % format_errors_in_response(guardrails_result["errors"]),
                            # no parameters
                            "inputSchema": {
                                "properties": {},
                                "required": [],
                                "title": "invariant_mcp_server_blockedArguments",
                                "type": "object",
                            },
                            "annotations": {
                                "title": "This tool was blocked by security guardrails.",
                            },
                        }
                        for tool in response_json["result"]["tools"]
                    ]
                },
            }

    # Push trace to the explorer - don't block on its response
    asyncio.create_task(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    return result, blocked
async def _handle_endpoint_event(
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]:
@@ -476,8 +349,9 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
method = session.id_to_method_mapping.get(response_json.get("id"))
if method == MCP_TOOL_CALL:
hook_tool_call_response, blocked = await _hook_tool_call_response(
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json=response_json,
)
# Update the event bytes with hook_tool_call_response.
@@ -485,8 +359,8 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
UTF_8
event_bytes = (
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
)
elif method == MCP_LIST_TOOLS:
# store tools in metadata
@@ -494,8 +368,9 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
hook_tool_call_response, blocked = await _hook_tool_call_response(
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json={
"id": response_json.get("id"),
"result": {
@@ -512,8 +387,8 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
UTF_8
event_bytes = (
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
)
except json.JSONDecodeError as e:
@@ -529,18 +404,6 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
return event_bytes
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
    """Return True if any annotation built from the guardrails errors is not
    yet among the session's annotations."""
    session = session_store.get_session(session_id)
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    # any() stops at the first unseen annotation, like the explicit loop did.
    return any(candidate not in session.annotations for candidate in candidates)
async def _check_for_pending_error_messages(
session_id: str, pending_error_messages_queue: asyncio.Queue
):

View File

@@ -1,6 +1,5 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import asyncio
import json
import uuid
@@ -13,8 +12,6 @@ from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import (
CLIENT_TIMEOUT,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
INVARIANT_SESSION_ID_PREFIX,
MCP_CLIENT_INFO,
MCP_LIST_TOOLS,
@@ -25,14 +22,15 @@ from gateway.common.constants import (
MCP_TOOL_CALL,
UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
)
from gateway.common.mcp_utils import get_mcp_server_base_url
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.mcp.log import format_errors_in_response
from gateway.common.mcp_utils import (
get_mcp_server_base_url,
hook_tool_call,
hook_tool_call_response,
)
gateway = APIRouter()
session_store = McpSessionsManager()
@@ -448,146 +446,6 @@ async def _handle_mcp_streaming_response(
)
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
    """Tell whether the guardrails result carries an annotation the session
    does not hold yet."""
    session = session_store.get_session(session_id)
    fresh_annotations = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    # Short-circuits on the first annotation missing from the session.
    return any(item not in session.annotations for item in fresh_annotations)
async def _hook_tool_call(session_id: str, request_body: dict) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The request is wrapped in an assistant message carrying a single tool
    call, checked against the session's blocking guardrails and recorded in
    the session.

    Args:
        session_id (str): The session ID associated with the request.
        request_body (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: ``(error_response, True)`` when the guardrails
        report new errors, otherwise ``(request_body, False)``.
    """
    # A request without a params object (or with params set to null) must
    # not crash the hook with an AttributeError on None.
    params = request_body.get(MCP_PARAMS) or {}
    tool_call = {
        "id": f"call_{request_body.get('id')}",
        "type": "function",
        "function": {
            "name": params.get("name"),
            "arguments": params.get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # The request is blocked only when the check produced errors that are
    # not already present in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        # Record the blocked message in the background so the error response
        # is not delayed.
        asyncio.create_task(
            session_store.add_message_to_session(
                session_id=session_id,
                message=message,
                guardrails_result=guardrails_result,
            )
        )
        return {
            "jsonrpc": "2.0",
            "id": request_body.get("id"),
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }, True

    # Not blocked: record the message before handing the original request back.
    await session_store.add_message_to_session(session_id, message, guardrails_result)
    return request_body, False
async def _hook_tool_call_response(
    session_id: str, response_json: dict, is_tools_list=False
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    The response is wrapped in a tool message, checked against the session's
    blocking guardrails and recorded in the session.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.
        is_tools_list (bool): When True, a blocked response is replaced by a
            tools list of "blocked_" placeholders instead of an error object.

    Returns:
        Tuple[dict, bool]: The response to send to the client and whether it
        was blocked. The original response JSON is returned if no guardrail
        is violated, else a replacement dict is returned.
    """
    blocked = False

    # Responses without a "result" member (for example JSON-RPC error
    # responses) must not crash the hook with an AttributeError on None.
    payload = response_json.get(MCP_RESULT) or {}
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": payload.get("content"),
        "error": payload.get("error"),
    }
    result = response_json

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Block only when the check produced errors that are not already present
    # in the session's annotations.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        blocked = True
        if not is_tools_list:
            # Regular tool call: answer with a JSON-RPC error object.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "error": {
                    "code": -32600,
                    "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                    % guardrails_result["errors"],
                },
            }
        else:
            # special error response for tools/list tool call: every tool is
            # replaced by a parameterless "blocked_" placeholder.
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "result": {
                    "tools": [
                        {
                            "name": "blocked_" + tool["name"],
                            "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                            % format_errors_in_response(guardrails_result["errors"]),
                            "inputSchema": {
                                "properties": {},
                                "required": [],
                                "title": "invariant_mcp_server_blockedArguments",
                                "type": "object",
                            },
                            "annotations": {
                                "title": "This tool was blocked by security guardrails.",
                            },
                        }
                        for tool in response_json["result"]["tools"]
                    ]
                },
            }

    # Push trace to the explorer - don't block on its response
    asyncio.create_task(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    return result, blocked
async def _intercept_request(session_id: str, request_body: dict) -> Response | None:
"""
Intercept the request and check for guardrails.
@@ -595,8 +453,10 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
If the request is blocked, it returns a message indicating the block reason.
"""
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
hook_tool_call_result, is_blocked = await _hook_tool_call(
session_id=session_id, request_body=request_body
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=session_store,
request_body=request_body,
)
if is_blocked:
return Response(
@@ -605,8 +465,9 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
media_type="application/json",
)
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
hook_tool_call_result, is_blocked = await _hook_tool_call(
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=session_store,
request_body={
"id": request_body.get("id"),
"method": MCP_LIST_TOOLS,
@@ -636,10 +497,12 @@ async def _intercept_response(
method = session.id_to_method_mapping.get(response_json.get("id"))
# Intercept and potentially block tool call response
if method == MCP_TOOL_CALL:
hook_tool_call_response, blocked = await _hook_tool_call_response(
session_id=session_id, response_json=response_json
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json=response_json,
)
return hook_tool_call_response, blocked
return result, blocked
# Intercept and potentially block list tool call response
elif method == MCP_LIST_TOOLS:
# store tools in metadata
@@ -647,8 +510,9 @@ async def _intercept_response(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
hook_tool_call_response, blocked = await _hook_tool_call_response(
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json={
"id": response_json.get("id"),
"result": {
@@ -658,5 +522,5 @@ async def _intercept_response(
},
is_tools_list=True,
)
return hook_tool_call_response, blocked
return result, blocked
return response_json, False