Move util methods to MCPTransportBase.

This commit is contained in:
Hemang
2025-06-03 17:53:45 +02:00
committed by Hemang Sarkar
parent 7ec13ad852
commit f6ba31ab2e
5 changed files with 334 additions and 339 deletions

View File

@@ -4,21 +4,31 @@ MCP Transport Strategy Pattern Implementation
This module defines an abstract base class for MCP transports.
"""
import asyncio
import json
import re
import uuid
from abc import ABC, abstractmethod
from typing import Any, Tuple
from fastapi import Request, HTTPException
from gateway.common.guardrails import GuardrailAction
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.mcp.constants import (
MCP_METHOD,
MCP_TOOL_CALL,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
INVARIANT_SESSION_ID_PREFIX,
MCP_CLIENT_INFO,
MCP_LIST_TOOLS,
MCP_METHOD,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_BASE_URL_HEADER,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
)
from gateway.mcp.mcp_sessions_manager import McpSessionsManager
from gateway.mcp.utils import (
hook_tool_call,
intercept_response,
update_mcp_server_in_session_metadata,
update_session_from_request,
)
from gateway.mcp.log import format_errors_in_response
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
class MCPTransportBase(ABC):
@@ -43,7 +53,7 @@ class MCPTransportBase(ABC):
"""
# Update session with request information
session = self.session_store.get_session(session_id)
update_session_from_request(session, request_data)
MCPTransportBase.update_session_from_request(session, request_data)
# Refresh guardrails
await session.load_guardrails()
@@ -65,10 +75,12 @@ class MCPTransportBase(ABC):
"""
# Update session with server information
session = self.session_store.get_session(session_id)
update_mcp_server_in_session_metadata(session, response_data)
MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data)
# Intercept and apply guardrails to response
return await intercept_response(session_id, self.session_store, response_data)
return await MCPTransportBase.intercept_response(
session_id, self.session_store, response_data
)
def _should_intercept_request(self, request_data: dict[str, Any]) -> bool:
"""Check if request should be intercepted for guardrails."""
@@ -84,11 +96,11 @@ class MCPTransportBase(ABC):
interception_result = request_data
is_blocked = False
if method == MCP_TOOL_CALL:
interception_result, is_blocked = await hook_tool_call(
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
session_id, self.session_store, request_data
)
elif method == MCP_LIST_TOOLS:
interception_result, is_blocked = await hook_tool_call(
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
session_id=session_id,
session_store=self.session_store,
request_body={
@@ -100,6 +112,305 @@ class MCPTransportBase(ABC):
return interception_result, is_blocked
@staticmethod
def generate_session_id() -> str:
    """Return a freshly generated session identifier.

    The identifier is ``INVARIANT_SESSION_ID_PREFIX`` followed by the
    hexadecimal form of a random UUID4.
    """
    random_suffix = uuid.uuid4().hex
    return INVARIANT_SESSION_ID_PREFIX + random_suffix
@staticmethod
def update_mcp_server_in_session_metadata(
    session: McpSession, response_body: dict
) -> None:
    """Record the MCP server's name in the session metadata.

    Nothing is written unless ``response_body`` carries a truthy
    ``MCP_RESULT`` entry that itself holds a truthy ``MCP_SERVER_INFO``
    entry. When the server info has no ``"name"`` key an empty string is
    stored under ``metadata["mcp_server"]``.

    Args:
        session: Session whose ``attributes.metadata`` is updated in place.
        response_body: Response payload received from the MCP server.
    """
    result_section = response_body.get(MCP_RESULT)
    if not result_section:
        return

    server_info = result_section.get(MCP_SERVER_INFO)
    if not server_info:
        return

    session.attributes.metadata["mcp_server"] = server_info.get("name", "")
@staticmethod
def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None:
    """Remember which method a request id belongs to.

    The ``id -> method`` entry is stored in ``session.id_to_method_mapping``
    so that the method can be looked up again when only the response (which
    carries the id but not the method) is available.

    Args:
        session: Session whose ``id_to_method_mapping`` is updated in place.
        request_body: Request payload; read for ``MCP_METHOD`` and ``"id"``.
    """
    method = request_body.get(MCP_METHOD)
    request_id = request_body.get("id")

    # Compare against None instead of relying on truthiness: 0 and "" are
    # usable request ids and must still be tracked, otherwise the response
    # to such a request can never be matched to its method. Only a request
    # without an id (a notification, which gets no response) is skipped.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method
@staticmethod
def update_mcp_client_info_in_session(
    session: McpSession, request_body: dict
) -> None:
    """Record the MCP client's name in the session metadata.

    The name is read from ``request_body[MCP_PARAMS][MCP_CLIENT_INFO]``.
    If either level is missing or falsy the session is left untouched; if
    the client info has no ``"name"`` key an empty string is stored under
    ``metadata["mcp_client"]``.

    Args:
        session: Session whose ``attributes.metadata`` is updated in place.
        request_body: Request payload sent by the MCP client.
    """
    params = request_body.get(MCP_PARAMS)
    client_info = params.get(MCP_CLIENT_INFO) if params else None
    if client_info:
        session.attributes.metadata["mcp_client"] = client_info.get("name", "")
@staticmethod
def update_session_from_request(session: McpSession, request_body: dict) -> None:
    """Apply every per-request session update for ``request_body``.

    Delegates, in this order, to ``update_mcp_client_info_in_session`` and
    ``update_tool_call_id_in_session``.

    Args:
        session: Session to update in place.
        request_body: Request payload sent by the MCP client.
    """
    updaters = (
        MCPTransportBase.update_mcp_client_info_in_session,
        MCPTransportBase.update_tool_call_id_in_session,
    )
    for apply_update in updaters:
        apply_update(session, request_body)
@staticmethod
def get_mcp_server_base_url(request: Request) -> str:
    """Read the MCP server base URL from the request headers.

    The header value is passed through ``convert_localhost_to_docker_host``
    and any trailing ``/`` characters are stripped.

    Args:
        request: Incoming request; the ``MCP_SERVER_BASE_URL_HEADER`` header
            is read from it.

    Returns:
        The converted base URL without trailing slashes.

    Raises:
        HTTPException: With status 400 when the header is absent or empty.
    """
    header_value = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
    if header_value:
        converted = MCPTransportBase.convert_localhost_to_docker_host(header_value)
        return converted.rstrip("/")

    raise HTTPException(
        status_code=400,
        detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
    )
@staticmethod
def convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""Convert localhost or 127.0.0.1 in an address to host.docker.internal."""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
@staticmethod
def check_if_new_errors(
    session_id: str,
    session_store: McpSessionsManager,
    guardrails_result: dict,
) -> bool:
    """Tell whether the guardrails result contains errors not yet seen.

    The errors in ``guardrails_result`` are converted to annotations with
    ``create_annotations_from_guardrails_errors`` and compared against the
    annotations already stored on the session.

    Args:
        session_id: Identifier of the session to compare against.
        session_store: Store used to look the session up.
        guardrails_result: Guardrails output; its ``"errors"`` list is read
            (treated as empty when absent).

    Returns:
        ``True`` if at least one resulting annotation is not in
        ``session.annotations``, otherwise ``False``.
    """
    session = session_store.get_session(session_id)
    candidate_annotations = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(
        candidate not in session.annotations for candidate in candidate_annotations
    )
# Strong references to fire-and-forget trace uploads started by
# hook_tool_call. The event loop only holds weak references to tasks, so an
# unreferenced task may be garbage-collected before it finishes.
_background_trace_tasks: set = set()

@staticmethod
async def hook_tool_call(
    session_id: str, session_store: McpSessionsManager, request_body: dict
) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The request is turned into an assistant message with a single tool call
    and checked against the session's blocking guardrails. The message is
    recorded through ``session_store.add_message_to_session`` in both
    outcomes: awaited when the request passes, scheduled in the background
    when it is blocked so the error can be returned immediately.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        request_body (dict): The request JSON to be processed. A missing or
            null ``MCP_PARAMS`` entry is treated as empty parameters.

    Returns:
        Tuple[dict, bool]: ``(payload, is_blocked)``. When blocked, ``payload``
        is a JSON-RPC error object carrying the guardrail errors; otherwise it
        is the original ``request_body``.
    """
    # Tolerate requests without params instead of failing on None.get(...).
    params = request_body.get(MCP_PARAMS) or {}
    request_id = request_body.get("id")

    message = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": f"call_{request_id}",
                "type": "function",
                "function": {
                    "name": params.get("name"),
                    "arguments": params.get("arguments"),
                },
            }
        ],
    }

    # Check for blocking guardrails
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Only errors that are not already annotated on the session block.
    is_blocked = bool(
        guardrails_result
        and guardrails_result.get("errors", [])
        and MCPTransportBase.check_if_new_errors(
            session_id, session_store, guardrails_result
        )
    )

    if not is_blocked:
        await session_store.add_message_to_session(
            session_id, message, guardrails_result
        )
        return request_body, False

    # Record the blocked call in the background, keeping a reference until
    # the task completes so it cannot be collected mid-execution.
    trace_task = asyncio.create_task(
        session_store.add_message_to_session(
            session_id=session_id,
            message=message,
            guardrails_result=guardrails_result,
        )
    )
    MCPTransportBase._background_trace_tasks.add(trace_task)
    trace_task.add_done_callback(MCPTransportBase._background_trace_tasks.discard)

    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": -32600,
            "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
            % guardrails_result["errors"],
        },
    }, True
@staticmethod
async def hook_tool_call_response(
    session_id: str,
    session_store: McpSessionsManager,
    response_body: dict,
    is_tools_list=False,
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    The response is turned into a tool message, checked against the session's
    blocking guardrails and then recorded through
    ``session_store.add_message_to_session``.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        response_body (dict): The response JSON to be processed.
        is_tools_list (bool): Flag to indicate if the response is from a tools/list call.

    Returns:
        Tuple[dict, bool]: ``(payload, is_blocked)``. When not blocked,
        ``payload`` is the original response. When blocked it is a JSON-RPC
        error object, or - for a tools/list response - a result whose tools
        are replaced by ``blocked_<name>`` placeholders describing the errors.
    """
    response_id = response_body.get("id")
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_id}",
        "content": response_body.get(MCP_RESULT, {}).get("content"),
        "error": response_body.get(MCP_RESULT, {}).get("error"),
    }

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Only errors that are not already annotated on the session block.
    is_blocked = bool(
        guardrails_result
        and guardrails_result.get("errors", [])
        and MCPTransportBase.check_if_new_errors(
            session_id, session_store, guardrails_result
        )
    )

    if not is_blocked:
        outcome = response_body
    elif is_tools_list:
        # Special error response for tools/list: keep one entry per tool but
        # replace it with a placeholder that explains why it was blocked.
        blocked_tools = []
        for tool in response_body.get("result", {}).get("tools", []):
            blocked_tools.append(
                {
                    "name": "blocked_" + tool["name"],
                    "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                    % format_errors_in_response(guardrails_result["errors"]),
                    "inputSchema": {
                        "properties": {},
                        "required": [],
                        "title": "invariant_mcp_server_blockedArguments",
                        "type": "object",
                    },
                    "annotations": {
                        "title": "This tool was blocked by security guardrails.",
                    },
                }
            )
        outcome = {
            "jsonrpc": "2.0",
            "id": response_id,
            "result": {"tools": blocked_tools},
        }
    else:
        outcome = {
            "jsonrpc": "2.0",
            "id": response_id,
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }

    # Record the message together with the guardrails outcome.
    await session_store.add_message_to_session(
        session_id, message, guardrails_result
    )
    return outcome, is_blocked
@staticmethod
async def intercept_response(
    session_id: str, session_store: McpSessionsManager, response_body: dict
) -> Tuple[dict, bool]:
    """
    Intercept the response and check for guardrails.

    The method that produced the response is looked up by response id in
    ``session.id_to_method_mapping``. Only ``MCP_TOOL_CALL`` and
    ``MCP_LIST_TOOLS`` responses are checked; anything else is passed
    through unchanged.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        response_body (dict): The response JSON to be processed.

    Returns:
        Tuple[dict, bool]: A tuple containing the processed response JSON
        and a boolean indicating whether the response was blocked.
    """
    session = session_store.get_session(session_id)
    method = session.id_to_method_mapping.get(response_body.get("id"))

    # Tool call response: check it as-is.
    if method == MCP_TOOL_CALL:
        return await MCPTransportBase.hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body=response_body,
        )

    # Tool listing: remember the tools, then check a synthetic response that
    # exposes the serialized tool list as its content.
    if method == MCP_LIST_TOOLS:
        tools = response_body.get(MCP_RESULT, {}).get("tools", [])
        session_store.get_session(session_id).attributes.metadata["tools"] = tools
        listing_response = {
            "jsonrpc": "2.0",
            "id": response_body.get("id"),
            "result": {
                "content": json.dumps(tools),
                "tools": tools,
            },
        }
        return await MCPTransportBase.hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body=listing_response,
            is_tools_list=True,
        )

    # Any other method is not subject to response guardrails.
    return response_body, False
@abstractmethod
async def initialize_session(self, *args, **kwargs) -> str:
"""Initialize a session for this transport type."""

View File

@@ -17,9 +17,6 @@ from gateway.mcp.mcp_sessions_manager import (
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.utils import (
get_mcp_server_base_url,
)
MCP_SERVER_POST_HEADERS = {
"connection",
@@ -122,7 +119,7 @@ class SSETransport(MCPTransportBase):
return Response(content="Accepted", status_code=202)
# Forward to MCP server
mcp_server_base_url = get_mcp_server_base_url(request)
mcp_server_base_url = self.get_mcp_server_base_url(request)
mcp_server_messages_endpoint = f"{mcp_server_base_url}/messages/?{session_id}"
# Filter headers for MCP server
@@ -151,7 +148,7 @@ class SSETransport(MCPTransportBase):
async def handle_sse_stream(self, request: Request) -> StreamingResponse:
"""Handle SSE streaming connection."""
mcp_server_base_url = get_mcp_server_base_url(request)
mcp_server_base_url = self.get_mcp_server_base_url(request)
mcp_server_sse_endpoint = f"{mcp_server_base_url}/sse"
query_params = dict(request.query_params)

View File

@@ -16,9 +16,6 @@ from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.utils import (
generate_session_id,
)
STATUS_EOF = "eof"
STATUS_DATA = "data"
@@ -39,7 +36,7 @@ class StdioTransport(MCPTransportBase):
async def initialize_session(self, *args, **kwargs) -> str:
"""Initialize session for stdio transport."""
session_attributes: McpAttributes = kwargs.get("session_attributes")
session_id = generate_session_id()
session_id = self.generate_session_id()
await self.session_store.initialize_session(session_id, session_attributes)
mcp_log(f"Created stdio session with ID: {session_id}")
return session_id

View File

@@ -18,13 +18,7 @@ from gateway.mcp.mcp_sessions_manager import (
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.utils import (
generate_session_id,
get_mcp_server_base_url,
update_mcp_client_info_in_session,
update_mcp_server_in_session_metadata,
update_tool_call_id_in_session,
)
gateway = APIRouter()
mcp_sessions_manager = McpSessionsManager()
@@ -103,7 +97,7 @@ class StreamableTransport(MCPTransportBase):
return session_id
if is_initialization_request and not session_id:
session_id = generate_session_id()
session_id = self.generate_session_id()
if (
session_id
@@ -124,7 +118,7 @@ class StreamableTransport(MCPTransportBase):
# Handle session initialization
if session_id:
update_tool_call_id_in_session(
self.update_tool_call_id_in_session(
self.session_store.get_session(session_id), request_body
)
elif is_initialization_request:
@@ -296,7 +290,7 @@ class StreamableTransport(MCPTransportBase):
# Update client info for initialization requests
if is_initialization_request:
update_mcp_client_info_in_session(
self.update_mcp_client_info_in_session(
self.session_store.get_session(session_id), request_body
)
@@ -398,7 +392,7 @@ class StreamableTransport(MCPTransportBase):
) -> None:
"""Update MCP response info in session metadata."""
session = self.session_store.get_session(session_id)
update_mcp_server_in_session_metadata(session, response_body)
self.update_mcp_server_in_session_metadata(session, response_body)
session.attributes.metadata["server_response_type"] = (
"json" if is_json_response else "sse"
)
@@ -426,7 +420,7 @@ class StreamableTransport(MCPTransportBase):
def _get_mcp_server_endpoint(self, request: Request) -> str:
"""Get MCP server endpoint URL."""
return get_mcp_server_base_url(request) + "/mcp/"
return self.get_mcp_server_base_url(request) + "/mcp/"
def _is_initialization_request(self, request_data: dict[str, Any]) -> bool:
"""Check if request is an initialization request."""

View File

@@ -1,304 +0,0 @@
"""MCP utility functions - Updated to work with transport strategy pattern."""
import asyncio
import json
import re
import uuid
from typing import Tuple
from fastapi import Request, HTTPException
from gateway.common.guardrails import GuardrailAction
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.mcp.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
INVARIANT_SESSION_ID_PREFIX,
MCP_CLIENT_INFO,
MCP_LIST_TOOLS,
MCP_METHOD,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_BASE_URL_HEADER,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
)
from gateway.mcp.log import format_errors_in_response
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
def generate_session_id() -> str:
    """Return a freshly generated session identifier.

    The identifier is ``INVARIANT_SESSION_ID_PREFIX`` followed by the
    hexadecimal form of a random UUID4.
    """
    random_suffix = uuid.uuid4().hex
    return INVARIANT_SESSION_ID_PREFIX + random_suffix
def update_mcp_server_in_session_metadata(
    session: McpSession, response_body: dict
) -> None:
    """Record the MCP server's name in the session metadata.

    Nothing is written unless ``response_body`` carries a truthy
    ``MCP_RESULT`` entry that itself holds a truthy ``MCP_SERVER_INFO``
    entry. When the server info has no ``"name"`` key an empty string is
    stored under ``metadata["mcp_server"]``.

    Args:
        session: Session whose ``attributes.metadata`` is updated in place.
        response_body: Response payload received from the MCP server.
    """
    result_section = response_body.get(MCP_RESULT)
    if not result_section:
        return

    server_info = result_section.get(MCP_SERVER_INFO)
    if not server_info:
        return

    session.attributes.metadata["mcp_server"] = server_info.get("name", "")
def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None:
    """Remember which method a request id belongs to.

    The ``id -> method`` entry is stored in ``session.id_to_method_mapping``
    so that the method can be looked up again when only the response (which
    carries the id but not the method) is available.

    Args:
        session: Session whose ``id_to_method_mapping`` is updated in place.
        request_body: Request payload; read for ``MCP_METHOD`` and ``"id"``.
    """
    method = request_body.get(MCP_METHOD)
    request_id = request_body.get("id")

    # Compare against None instead of relying on truthiness: 0 and "" are
    # usable request ids and must still be tracked, otherwise the response
    # to such a request can never be matched to its method. Only a request
    # without an id (a notification, which gets no response) is skipped.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method
def update_mcp_client_info_in_session(session: McpSession, request_body: dict) -> None:
    """Record the MCP client's name in the session metadata.

    The name is read from ``request_body[MCP_PARAMS][MCP_CLIENT_INFO]``.
    If either level is missing or falsy the session is left untouched; if
    the client info has no ``"name"`` key an empty string is stored under
    ``metadata["mcp_client"]``.

    Args:
        session: Session whose ``attributes.metadata`` is updated in place.
        request_body: Request payload sent by the MCP client.
    """
    params = request_body.get(MCP_PARAMS)
    client_info = params.get(MCP_CLIENT_INFO) if params else None
    if client_info:
        session.attributes.metadata["mcp_client"] = client_info.get("name", "")
def update_session_from_request(session: McpSession, request_body: dict) -> None:
    """Apply every per-request session update for ``request_body``.

    Delegates, in this order, to ``update_mcp_client_info_in_session`` and
    ``update_tool_call_id_in_session``.

    Args:
        session: Session to update in place.
        request_body: Request payload sent by the MCP client.
    """
    updaters = (
        update_mcp_client_info_in_session,
        update_tool_call_id_in_session,
    )
    for apply_update in updaters:
        apply_update(session, request_body)
def get_mcp_server_base_url(request: Request) -> str:
    """Read the MCP server base URL from the request headers.

    The header value is passed through ``_convert_localhost_to_docker_host``
    and any trailing ``/`` characters are stripped.

    Args:
        request: Incoming request; the ``MCP_SERVER_BASE_URL_HEADER`` header
            is read from it.

    Returns:
        The converted base URL without trailing slashes.

    Raises:
        HTTPException: With status 400 when the header is absent or empty.
    """
    header_value = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
    if header_value:
        converted = _convert_localhost_to_docker_host(header_value)
        return converted.rstrip("/")

    raise HTTPException(
        status_code=400,
        detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
    )
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""Convert localhost or 127.0.0.1 in an address to host.docker.internal."""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
def _check_if_new_errors(
    session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
    """Tell whether the guardrails result contains errors not yet seen.

    The errors in ``guardrails_result`` are converted to annotations with
    ``create_annotations_from_guardrails_errors`` and compared against the
    annotations already stored on the session.

    Args:
        session_id: Identifier of the session to compare against.
        session_store: Store used to look the session up.
        guardrails_result: Guardrails output; its ``"errors"`` list is read
            (treated as empty when absent).

    Returns:
        ``True`` if at least one resulting annotation is not in
        ``session.annotations``, otherwise ``False``.
    """
    session = session_store.get_session(session_id)
    candidate_annotations = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(
        candidate not in session.annotations for candidate in candidate_annotations
    )
# Strong references to fire-and-forget trace uploads started by
# hook_tool_call. The event loop only holds weak references to tasks, so an
# unreferenced task may be garbage-collected before it finishes.
_BACKGROUND_TRACE_TASKS: set = set()


async def hook_tool_call(
    session_id: str, session_store: McpSessionsManager, request_body: dict
) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The request is turned into an assistant message with a single tool call
    and checked against the session's blocking guardrails. The message is
    recorded through ``session_store.add_message_to_session`` in both
    outcomes: awaited when the request passes, scheduled in the background
    when it is blocked so the error can be returned immediately.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        request_body (dict): The request JSON to be processed. A missing or
            null ``MCP_PARAMS`` entry is treated as empty parameters.

    Returns:
        Tuple[dict, bool]: ``(payload, is_blocked)``. When blocked, ``payload``
        is a JSON-RPC error object carrying the guardrail errors; otherwise it
        is the original ``request_body``.
    """
    # Tolerate requests without params instead of failing on None.get(...).
    params = request_body.get(MCP_PARAMS) or {}
    request_id = request_body.get("id")

    message = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": f"call_{request_id}",
                "type": "function",
                "function": {
                    "name": params.get("name"),
                    "arguments": params.get("arguments"),
                },
            }
        ],
    }

    # Check for blocking guardrails
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Only errors that are not already annotated on the session block.
    is_blocked = bool(
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, session_store, guardrails_result)
    )

    if not is_blocked:
        await session_store.add_message_to_session(
            session_id, message, guardrails_result
        )
        return request_body, False

    # Record the blocked call in the background, keeping a reference until
    # the task completes so it cannot be collected mid-execution.
    trace_task = asyncio.create_task(
        session_store.add_message_to_session(
            session_id=session_id,
            message=message,
            guardrails_result=guardrails_result,
        )
    )
    _BACKGROUND_TRACE_TASKS.add(trace_task)
    trace_task.add_done_callback(_BACKGROUND_TRACE_TASKS.discard)

    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": -32600,
            "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
            % guardrails_result["errors"],
        },
    }, True
async def hook_tool_call_response(
    session_id: str,
    session_store: McpSessionsManager,
    response_body: dict,
    is_tools_list=False,
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    The response is turned into a tool message, checked against the session's
    blocking guardrails and then recorded through
    ``session_store.add_message_to_session``.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        response_body (dict): The response JSON to be processed.
        is_tools_list (bool): Flag to indicate if the response is from a tools/list call.

    Returns:
        Tuple[dict, bool]: ``(payload, is_blocked)``. When not blocked,
        ``payload`` is the original response. When blocked it is a JSON-RPC
        error object, or - for a tools/list response - a result whose tools
        are replaced by ``blocked_<name>`` placeholders describing the errors.
    """
    response_id = response_body.get("id")
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_id}",
        "content": response_body.get(MCP_RESULT, {}).get("content"),
        "error": response_body.get(MCP_RESULT, {}).get("error"),
    }

    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )

    # Only errors that are not already annotated on the session block.
    is_blocked = bool(
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, session_store, guardrails_result)
    )

    if not is_blocked:
        outcome = response_body
    elif is_tools_list:
        # Special error response for tools/list: keep one entry per tool but
        # replace it with a placeholder that explains why it was blocked.
        blocked_tools = []
        for tool in response_body.get("result", {}).get("tools", []):
            blocked_tools.append(
                {
                    "name": "blocked_" + tool["name"],
                    "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                    % format_errors_in_response(guardrails_result["errors"]),
                    "inputSchema": {
                        "properties": {},
                        "required": [],
                        "title": "invariant_mcp_server_blockedArguments",
                        "type": "object",
                    },
                    "annotations": {
                        "title": "This tool was blocked by security guardrails.",
                    },
                }
            )
        outcome = {
            "jsonrpc": "2.0",
            "id": response_id,
            "result": {"tools": blocked_tools},
        }
    else:
        outcome = {
            "jsonrpc": "2.0",
            "id": response_id,
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }

    # Record the message together with the guardrails outcome.
    await session_store.add_message_to_session(session_id, message, guardrails_result)
    return outcome, is_blocked
async def intercept_response(
    session_id: str, session_store: McpSessionsManager, response_body: dict
) -> Tuple[dict, bool]:
    """
    Intercept the response and check for guardrails.

    The method that produced the response is looked up by response id in
    ``session.id_to_method_mapping``. Only ``MCP_TOOL_CALL`` and
    ``MCP_LIST_TOOLS`` responses are checked; anything else is passed
    through unchanged.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        response_body (dict): The response JSON to be processed.

    Returns:
        Tuple[dict, bool]: A tuple containing the processed response JSON
        and a boolean indicating whether the response was blocked.
    """
    session = session_store.get_session(session_id)
    method = session.id_to_method_mapping.get(response_body.get("id"))

    # Tool call response: check it as-is.
    if method == MCP_TOOL_CALL:
        return await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body=response_body,
        )

    # Tool listing: remember the tools, then check a synthetic response that
    # exposes the serialized tool list as its content.
    if method == MCP_LIST_TOOLS:
        tools = response_body.get(MCP_RESULT, {}).get("tools", [])
        session_store.get_session(session_id).attributes.metadata["tools"] = tools
        listing_response = {
            "jsonrpc": "2.0",
            "id": response_body.get("id"),
            "result": {
                "content": json.dumps(tools),
                "tools": tools,
            },
        }
        return await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body=listing_response,
            is_tools_list=True,
        )

    # Any other method is not subject to response guardrails.
    return response_body, False