diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index 6c9c5ff..29d728c 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -4,21 +4,31 @@ MCP Transport Strategy Pattern Implementation This module defines an abstract base class for MCP transports. """ +import asyncio +import json +import re +import uuid from abc import ABC, abstractmethod from typing import Any, Tuple +from fastapi import Request, HTTPException +from gateway.common.guardrails import GuardrailAction +from gateway.integrations.explorer import create_annotations_from_guardrails_errors from gateway.mcp.constants import ( - MCP_METHOD, - MCP_TOOL_CALL, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, + INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE, + INVARIANT_SESSION_ID_PREFIX, + MCP_CLIENT_INFO, MCP_LIST_TOOLS, + MCP_METHOD, + MCP_PARAMS, + MCP_RESULT, + MCP_SERVER_BASE_URL_HEADER, + MCP_SERVER_INFO, + MCP_TOOL_CALL, ) -from gateway.mcp.mcp_sessions_manager import McpSessionsManager -from gateway.mcp.utils import ( - hook_tool_call, - intercept_response, - update_mcp_server_in_session_metadata, - update_session_from_request, -) +from gateway.mcp.log import format_errors_in_response +from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager class MCPTransportBase(ABC): @@ -43,7 +53,7 @@ class MCPTransportBase(ABC): """ # Update session with request information session = self.session_store.get_session(session_id) - update_session_from_request(session, request_data) + MCPTransportBase.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -65,10 +75,12 @@ class MCPTransportBase(ABC): """ # Update session with server information session = self.session_store.get_session(session_id) - update_mcp_server_in_session_metadata(session, response_data) + MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response - return await intercept_response(session_id, self.session_store, response_data) + return await MCPTransportBase.intercept_response( + session_id, self.session_store, response_data + ) def _should_intercept_request(self, request_data: dict[str, Any]) -> bool: """Check if request should be intercepted for guardrails.""" @@ -84,11 +96,11 @@ class MCPTransportBase(ABC): interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await hook_tool_call( + interception_result, is_blocked = await MCPTransportBase.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await hook_tool_call( + interception_result, is_blocked = await MCPTransportBase.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ @@ -100,6 +112,305 @@ class MCPTransportBase(ABC): return interception_result, is_blocked + @staticmethod + def generate_session_id() -> str: + """Generate a new session ID.""" + return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex + + @staticmethod + def update_mcp_server_in_session_metadata( + session: McpSession, response_body: dict + ) -> None: + """Update the MCP server information in the session metadata.""" + if response_body.get(MCP_RESULT) and response_body.get(MCP_RESULT).get( + MCP_SERVER_INFO + ): + session.attributes.metadata["mcp_server"] = ( + response_body.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") + ) + + @staticmethod + def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None: + """Updates the tool call ID in the session.""" + if request_body.get(MCP_METHOD) and request_body.get("id"): + session.id_to_method_mapping[request_body.get("id")] = request_body.get( + MCP_METHOD + ) + + @staticmethod + def update_mcp_client_info_in_session( + session: McpSession, request_body: dict + ) -> None: + """Update the MCP client info in the session metadata.""" + if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( + MCP_CLIENT_INFO + ): + session.attributes.metadata["mcp_client"] = ( + request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") + ) + + @staticmethod + def update_session_from_request(session: McpSession, request_body: dict) -> None: + """Update the MCP client information and request id in the session.""" + MCPTransportBase.update_mcp_client_info_in_session(session, request_body) + MCPTransportBase.update_tool_call_id_in_session(session, request_body) + + @staticmethod + def get_mcp_server_base_url(request: Request) -> str: + """Extract the MCP server base URL from the request headers.""" + mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) + if not mcp_server_base_url: + raise HTTPException( + status_code=400, + detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", + ) + return MCPTransportBase.convert_localhost_to_docker_host( + mcp_server_base_url + ).rstrip("/") + + @staticmethod + def convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: + """Convert localhost or 127.0.0.1 in an address to host.docker.internal.""" + if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url: + modified_address = re.sub( + r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)", + r"\1host.docker.internal\2", + mcp_server_base_url, + ) + return modified_address + return mcp_server_base_url + + @staticmethod + def check_if_new_errors( + session_id: str, + session_store: McpSessionsManager, + guardrails_result: dict, + ) -> bool: + """Checks if there are new errors in the guardrails result.""" + session = session_store.get_session(session_id) + annotations = create_annotations_from_guardrails_errors( + guardrails_result.get("errors", []) + ) + for annotation in annotations: + if annotation not in session.annotations: + return True + return False + + @staticmethod + async def hook_tool_call( + session_id: str, session_store: McpSessionsManager, request_body: dict + ) -> Tuple[dict, bool]: + """ + Hook to process the request JSON before sending it to the MCP server. + + Args: + session_id (str): The session ID associated with the request. + session_store (McpSessionsManager): The session store to manage sessions. + request_body (dict): The request JSON to be processed. + + Returns: + Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean + indicating whether the request was blocked. If the request is blocked, the + dict will contain an error message else it will contain the original request. + """ + tool_call = { + "id": f"call_{request_body.get('id')}", + "type": "function", + "function": { + "name": request_body.get(MCP_PARAMS).get("name"), + "arguments": request_body.get(MCP_PARAMS).get("arguments"), + }, + } + message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} + + # Check for blocking guardrails + session = session_store.get_session(session_id) + guardrails_result = await session.get_guardrails_check_result( + message, action=GuardrailAction.BLOCK + ) + + # If the request is blocked, return error message + if ( + guardrails_result + and guardrails_result.get("errors", []) + and MCPTransportBase.check_if_new_errors( + session_id, session_store, guardrails_result + ) + ): + # Add the trace to the explorer + asyncio.create_task( + session_store.add_message_to_session( + session_id=session_id, + message=message, + guardrails_result=guardrails_result, + ) + ) + return { + "jsonrpc": "2.0", + "id": request_body.get("id"), + "error": { + "code": -32600, + "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE + % guardrails_result["errors"], + }, + }, True + + # Push trace to the explorer + await session_store.add_message_to_session( + session_id, message, guardrails_result + ) + return request_body, False + + @staticmethod + async def hook_tool_call_response( + session_id: str, + session_store: McpSessionsManager, + response_body: dict, + is_tools_list=False, + ) -> Tuple[dict, bool]: + """ + + Hook to process the response JSON after receiving it from the MCP server. + Args: + session_id (str): The session ID associated with the request. + session_store (McpSessionsManager): The session store to manage sessions. + response_body (dict): The response JSON to be processed. + is_tools_list (bool): Flag to indicate if the response is from a tools/list call. + Returns: + Tuple[dict, bool]: A tuple containing the processed response JSON + and a boolean indicating whether the response was blocked. If the response + is blocked, the dict will contain an error message else it will contain the + original response. + """ + is_blocked = False + result = response_body + + message = { + "role": "tool", + "tool_call_id": f"call_{result.get('id')}", + "content": result.get(MCP_RESULT, {}).get("content"), + "error": result.get(MCP_RESULT, {}).get("error"), + } + + session = session_store.get_session(session_id) + guardrails_result = await session.get_guardrails_check_result( + message, action=GuardrailAction.BLOCK + ) + + if ( + guardrails_result + and guardrails_result.get("errors", []) + and MCPTransportBase.check_if_new_errors( + session_id, session_store, guardrails_result + ) + ): + is_blocked = True + + if not is_tools_list: + result = { + "jsonrpc": "2.0", + "id": response_body.get("id"), + "error": { + "code": -32600, + "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE + % guardrails_result["errors"], + }, + } + else: + # Special error response for tools/list + result = { + "jsonrpc": "2.0", + "id": response_body.get("id"), + "result": { + "tools": [ + { + "name": "blocked_" + tool["name"], + "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE + % format_errors_in_response( + guardrails_result["errors"] + ), + "inputSchema": { + "properties": {}, + "required": [], + "title": "invariant_mcp_server_blockedArguments", + "type": "object", + }, + "annotations": { + "title": "This tool was blocked by security guardrails.", + }, + } + for tool in response_body.get("result", {}).get("tools", []) + ] + }, + } + + # Push trace to the explorer + await session_store.add_message_to_session( + session_id, message, guardrails_result + ) + return result, is_blocked + + @staticmethod + async def intercept_response( + session_id: str, session_store: McpSessionsManager, response_body: dict + ) -> Tuple[dict, bool]: + """ + Intercept the response and check for guardrails. + This function is used to intercept responses and check for guardrails. + If the response is blocked, it returns a message indicating the block + reason with a boolean flag set to True. If the response is not blocked, + it returns the original response with a boolean flag set to False. + + Args: + session_id (str): The session ID associated with the request. + session_store (McpSessionsManager): The session store to manage sessions. + response_body (dict): The response JSON to be processed. + + Returns: + Tuple[dict, bool]: A tuple containing the processed response JSON + and a boolean indicating whether the response was blocked. + """ + session = session_store.get_session(session_id) + method = session.id_to_method_mapping.get(response_body.get("id")) + + intercept_response_result = response_body + is_blocked = False + + # Intercept and potentially block tool call response + if method == MCP_TOOL_CALL: + ( + intercept_response_result, + is_blocked, + ) = await MCPTransportBase.hook_tool_call_response( + session_id=session_id, + session_store=session_store, + response_body=response_body, + ) + # Intercept and potentially block list tool call response + elif method == MCP_LIST_TOOLS: + # Store tools in metadata + tools = response_body.get(MCP_RESULT, {}).get("tools", []) + session_store.get_session(session_id).attributes.metadata["tools"] = tools + + ( + intercept_response_result, + is_blocked, + ) = await MCPTransportBase.hook_tool_call_response( + session_id=session_id, + session_store=session_store, + response_body={ + "jsonrpc": "2.0", + "id": response_body.get("id"), + "result": { + "content": json.dumps(tools), + "tools": tools, + }, + }, + is_tools_list=True, + ) + + return intercept_response_result, is_blocked + @abstractmethod async def initialize_session(self, *args, **kwargs) -> str: """Initialize a session for this transport type.""" diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 3feb47a..7a8fa5e 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -17,9 +17,6 @@ from gateway.mcp.mcp_sessions_manager import ( McpAttributes, ) from gateway.mcp.mcp_transport_base import MCPTransportBase -from gateway.mcp.utils import ( - get_mcp_server_base_url, -) MCP_SERVER_POST_HEADERS = { "connection", @@ -122,7 +119,7 @@ class SSETransport(MCPTransportBase): return Response(content="Accepted", status_code=202) # Forward to MCP server - mcp_server_base_url = get_mcp_server_base_url(request) + mcp_server_base_url = self.get_mcp_server_base_url(request) mcp_server_messages_endpoint = f"{mcp_server_base_url}/messages/?{session_id}" # Filter headers for MCP server @@ -151,7 +148,7 @@ class SSETransport(MCPTransportBase): async def handle_sse_stream(self, request: Request) -> StreamingResponse: """Handle SSE streaming connection.""" - mcp_server_base_url = get_mcp_server_base_url(request) + mcp_server_base_url = self.get_mcp_server_base_url(request) mcp_server_sse_endpoint = f"{mcp_server_base_url}/sse" query_params = dict(request.query_params) diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py index c86b07c..4a5fd73 100644 --- a/gateway/mcp/stdio.py +++ b/gateway/mcp/stdio.py @@ -16,9 +16,6 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, ) from gateway.mcp.mcp_transport_base import MCPTransportBase -from gateway.mcp.utils import ( - generate_session_id, -) STATUS_EOF = "eof" STATUS_DATA = "data" @@ -39,7 +36,7 @@ class StdioTransport(MCPTransportBase): async def initialize_session(self, *args, **kwargs) -> str: """Initialize session for stdio transport.""" session_attributes: McpAttributes = kwargs.get("session_attributes") - session_id = generate_session_id() + session_id = self.generate_session_id() await self.session_store.initialize_session(session_id, session_attributes) mcp_log(f"Created stdio session with ID: {session_id}") return session_id diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index 08d0440..aa89c70 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -18,13 +18,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpAttributes, ) from gateway.mcp.mcp_transport_base import MCPTransportBase -from gateway.mcp.utils import ( - generate_session_id, - get_mcp_server_base_url, - update_mcp_client_info_in_session, - update_mcp_server_in_session_metadata, - update_tool_call_id_in_session, -) + gateway = APIRouter() mcp_sessions_manager = McpSessionsManager() @@ -103,7 +97,7 @@ class StreamableTransport(MCPTransportBase): return session_id if is_initialization_request and not session_id: - session_id = generate_session_id() + session_id = self.generate_session_id() if ( session_id @@ -124,7 +118,7 @@ class StreamableTransport(MCPTransportBase): # Handle session initialization if session_id: - update_tool_call_id_in_session( + self.update_tool_call_id_in_session( self.session_store.get_session(session_id), request_body ) elif is_initialization_request: @@ -296,7 +290,7 @@ class StreamableTransport(MCPTransportBase): # Update client info for initialization requests if is_initialization_request: - update_mcp_client_info_in_session( + self.update_mcp_client_info_in_session( self.session_store.get_session(session_id), request_body ) @@ -398,7 +392,7 @@ class StreamableTransport(MCPTransportBase): ) -> None: """Update MCP response info in session metadata.""" session = self.session_store.get_session(session_id) - update_mcp_server_in_session_metadata(session, response_body) + self.update_mcp_server_in_session_metadata(session, response_body) session.attributes.metadata["server_response_type"] = ( "json" if is_json_response else "sse" ) @@ -426,7 +420,7 @@ class StreamableTransport(MCPTransportBase): def _get_mcp_server_endpoint(self, request: Request) -> str: """Get MCP server endpoint URL.""" - return get_mcp_server_base_url(request) + "/mcp/" + return self.get_mcp_server_base_url(request) + "/mcp/" def _is_initialization_request(self, request_data: dict[str, Any]) -> bool: """Check if request is an initialization request.""" diff --git a/gateway/mcp/utils.py b/gateway/mcp/utils.py deleted file mode 100644 index 831ecc3..0000000 --- a/gateway/mcp/utils.py +++ /dev/null @@ -1,304 +0,0 @@ -"""MCP utility functions - Updated to work with transport strategy pattern.""" - -import asyncio -import json -import re -import uuid -from typing import Tuple - -from fastapi import Request, HTTPException - -from gateway.common.guardrails import GuardrailAction -from gateway.integrations.explorer import create_annotations_from_guardrails_errors -from gateway.mcp.constants import ( - INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, - INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE, - INVARIANT_SESSION_ID_PREFIX, - MCP_CLIENT_INFO, - MCP_LIST_TOOLS, - MCP_METHOD, - MCP_PARAMS, - MCP_RESULT, - MCP_SERVER_BASE_URL_HEADER, - MCP_SERVER_INFO, - MCP_TOOL_CALL, -) -from gateway.mcp.log import format_errors_in_response -from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager - - -def generate_session_id() -> str: - """Generate a new session ID.""" - return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex - - -def update_mcp_server_in_session_metadata( - session: McpSession, response_body: dict -) -> None: - """Update the MCP server information in the session metadata.""" - if response_body.get(MCP_RESULT) and response_body.get(MCP_RESULT).get( - MCP_SERVER_INFO - ): - session.attributes.metadata["mcp_server"] = ( - response_body.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") - ) - - -def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None: - """Updates the tool call ID in the session.""" - if request_body.get(MCP_METHOD) and request_body.get("id"): - session.id_to_method_mapping[request_body.get("id")] = request_body.get( - MCP_METHOD - ) - - -def update_mcp_client_info_in_session(session: McpSession, request_body: dict) -> None: - """Update the MCP client info in the session metadata.""" - if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( - MCP_CLIENT_INFO - ): - session.attributes.metadata["mcp_client"] = ( - request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") - ) - - -def update_session_from_request(session: McpSession, request_body: dict) -> None: - """Update the MCP client information and request id in the session.""" - update_mcp_client_info_in_session(session, request_body) - update_tool_call_id_in_session(session, request_body) - - -def get_mcp_server_base_url(request: Request) -> str: - """Extract the MCP server base URL from the request headers.""" - mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) - if not mcp_server_base_url: - raise HTTPException( - status_code=400, - detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", - ) - return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/") - - -def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: - """Convert localhost or 127.0.0.1 in an address to host.docker.internal.""" - if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url: - modified_address = re.sub( - r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)", - r"\1host.docker.internal\2", - mcp_server_base_url, - ) - return modified_address - return mcp_server_base_url - - -def _check_if_new_errors( - session_id: str, session_store: McpSessionsManager, guardrails_result: dict -) -> bool: - """Checks if there are new errors in the guardrails result.""" - session = session_store.get_session(session_id) - annotations = create_annotations_from_guardrails_errors( - guardrails_result.get("errors", []) - ) - for annotation in annotations: - if annotation not in session.annotations: - return True - return False - - -async def hook_tool_call( - session_id: str, session_store: McpSessionsManager, request_body: dict -) -> Tuple[dict, bool]: - """ - Hook to process the request JSON before sending it to the MCP server. - - Args: - session_id (str): The session ID associated with the request. - session_store (McpSessionsManager): The session store to manage sessions. - request_body (dict): The request JSON to be processed. - - Returns: - Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean - indicating whether the request was blocked. If the request is blocked, the - dict will contain an error message else it will contain the original request. - """ - tool_call = { - "id": f"call_{request_body.get('id')}", - "type": "function", - "function": { - "name": request_body.get(MCP_PARAMS).get("name"), - "arguments": request_body.get(MCP_PARAMS).get("arguments"), - }, - } - message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} - - # Check for blocking guardrails - session = session_store.get_session(session_id) - guardrails_result = await session.get_guardrails_check_result( - message, action=GuardrailAction.BLOCK - ) - - # If the request is blocked, return error message - if ( - guardrails_result - and guardrails_result.get("errors", []) - and _check_if_new_errors(session_id, session_store, guardrails_result) - ): - # Add the trace to the explorer - asyncio.create_task( - session_store.add_message_to_session( - session_id=session_id, - message=message, - guardrails_result=guardrails_result, - ) - ) - return { - "jsonrpc": "2.0", - "id": request_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - }, True - - # Push trace to the explorer - await session_store.add_message_to_session(session_id, message, guardrails_result) - return request_body, False - - -async def hook_tool_call_response( - session_id: str, - session_store: McpSessionsManager, - response_body: dict, - is_tools_list=False, -) -> Tuple[dict, bool]: - """ - - Hook to process the response JSON after receiving it from the MCP server. - Args: - session_id (str): The session ID associated with the request. - session_store (McpSessionsManager): The session store to manage sessions. - response_body (dict): The response JSON to be processed. - is_tools_list (bool): Flag to indicate if the response is from a tools/list call. - Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON - and a boolean indicating whether the response was blocked. If the response - is blocked, the dict will contain an error message else it will contain the - original response. - """ - is_blocked = False - result = response_body - - message = { - "role": "tool", - "tool_call_id": f"call_{result.get('id')}", - "content": result.get(MCP_RESULT, {}).get("content"), - "error": result.get(MCP_RESULT, {}).get("error"), - } - - session = session_store.get_session(session_id) - guardrails_result = await session.get_guardrails_check_result( - message, action=GuardrailAction.BLOCK - ) - - if ( - guardrails_result - and guardrails_result.get("errors", []) - and _check_if_new_errors(session_id, session_store, guardrails_result) - ): - is_blocked = True - - if not is_tools_list: - result = { - "jsonrpc": "2.0", - "id": response_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - } - else: - # Special error response for tools/list - result = { - "jsonrpc": "2.0", - "id": response_body.get("id"), - "result": { - "tools": [ - { - "name": "blocked_" + tool["name"], - "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE - % format_errors_in_response(guardrails_result["errors"]), - "inputSchema": { - "properties": {}, - "required": [], - "title": "invariant_mcp_server_blockedArguments", - "type": "object", - }, - "annotations": { - "title": "This tool was blocked by security guardrails.", - }, - } - for tool in response_body.get("result", {}).get("tools", []) - ] - }, - } - - # Push trace to the explorer - await session_store.add_message_to_session(session_id, message, guardrails_result) - return result, is_blocked - - -async def intercept_response( - session_id: str, session_store: McpSessionsManager, response_body: dict -) -> Tuple[dict, bool]: - """ - Intercept the response and check for guardrails. - This function is used to intercept responses and check for guardrails. - If the response is blocked, it returns a message indicating the block - reason with a boolean flag set to True. If the response is not blocked, - it returns the original response with a boolean flag set to False. - - Args: - session_id (str): The session ID associated with the request. - session_store (McpSessionsManager): The session store to manage sessions. - response_body (dict): The response JSON to be processed. - - Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON - and a boolean indicating whether the response was blocked. - """ - session = session_store.get_session(session_id) - method = session.id_to_method_mapping.get(response_body.get("id")) - - intercept_response_result = response_body - is_blocked = False - - # Intercept and potentially block tool call response - if method == MCP_TOOL_CALL: - intercept_response_result, is_blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_body=response_body, - ) - # Intercept and potentially block list tool call response - elif method == MCP_LIST_TOOLS: - # Store tools in metadata - tools = response_body.get(MCP_RESULT, {}).get("tools", []) - session_store.get_session(session_id).attributes.metadata["tools"] = tools - - intercept_response_result, is_blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_body={ - "jsonrpc": "2.0", - "id": response_body.get("id"), - "result": { - "content": json.dumps(tools), - "tools": tools, - }, - }, - is_tools_list=True, - ) - - return intercept_response_result, is_blocked