diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index 7557472..6367460 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -4,7 +4,6 @@ MCP Transport Strategy Pattern Implementation This module defines an abstract base class for MCP transports. """ -import asyncio import json import re import uuid @@ -31,7 +30,7 @@ from gateway.mcp.log import format_errors_in_response from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager -class MCPTransportBase(ABC): +class McpTransportBase(ABC): """ Abstract base class for MCP transport strategies. @@ -53,7 +52,7 @@ class MCPTransportBase(ABC): """ # Update session with request information session = self.session_store.get_session(session_id) - MCPTransportBase.update_session_from_request(session, request_data) + McpTransportBase.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -75,10 +74,10 @@ class MCPTransportBase(ABC): """ # Update session with server information session = self.session_store.get_session(session_id) - MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data) + McpTransportBase.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response - return await MCPTransportBase.intercept_response( + return await McpTransportBase.intercept_response( session_id, self.session_store, response_data ) @@ -87,6 +86,17 @@ class MCPTransportBase(ABC): method = request_data.get(MCP_METHOD) return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS] + @staticmethod + def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict: + return { + "jsonrpc": "2.0", + "id": request_body.get("id"), + "error": { + "code": -32600, + "message": message, + }, + } + async def _intercept_outgoing_request( self, session_id: str, request_data: dict[str, Any] ) -> Tuple[dict[str, Any], bool]: @@ -96,11 +106,11 @@ class MCPTransportBase(ABC): interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await McpTransportBase.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await McpTransportBase.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ @@ -152,8 +162,8 @@ class MCPTransportBase(ABC): @staticmethod def update_session_from_request(session: McpSession, request_body: dict) -> None: """Update the MCP client information and request id in the session.""" - MCPTransportBase.update_mcp_client_info_in_session(session, request_body) - MCPTransportBase.update_tool_call_id_in_session(session, request_body) + McpTransportBase.update_mcp_client_info_in_session(session, request_body) + McpTransportBase.update_tool_call_id_in_session(session, request_body) @staticmethod def get_mcp_server_base_url(request: Request) -> str: @@ -164,7 +174,7 @@ class MCPTransportBase(ABC): status_code=400, detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) - return MCPTransportBase.convert_localhost_to_docker_host( + return McpTransportBase.convert_localhost_to_docker_host( mcp_server_base_url ).rstrip("/") @@ -233,7 +243,7 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): @@ -243,15 +253,10 @@ class MCPTransportBase(ABC): message=message, guardrails_result=guardrails_result, ) - return { - "jsonrpc": "2.0", - "id": request_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - }, True + return McpTransportBase._create_jsonrpc_error_response( + request_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ), True # Push trace to the explorer await session_store.add_message_to_session( @@ -298,22 +303,17 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): is_blocked = True if not is_tools_list: - result = { - "jsonrpc": "2.0", - "id": response_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - } + result = McpTransportBase._create_jsonrpc_error_response( + response_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ) else: # Special error response for tools/list result = { @@ -379,7 +379,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body=response_body, @@ -393,7 +393,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body={ @@ -410,9 +410,9 @@ class MCPTransportBase(ABC): return intercept_response_result, is_blocked @abstractmethod - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize a session for this transport type.""" @abstractmethod - async def handle_communication(self, *args, **kwargs) -> Any: + async def handle_communication(self, **kwargs) -> Any: """Handle the main communication for this transport.""" diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 3109175..998d5de 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -16,7 +16,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase MCP_SERVER_POST_HEADERS = { "connection", @@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post( raise HTTPException(status_code=400, detail="Session does not exist") request_body = json.loads(await request.body()) - return await SSETransport(session_store).handle_post_request( + return await SseTransport(session_store).handle_post_request( request, session_id, request_body ) @@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream( request: Request, session_store: McpSessionsManager ) -> StreamingResponse: """Integration function for SSE GET route.""" - return await SSETransport(session_store).handle_sse_stream(request) + return await SseTransport(session_store).handle_sse_stream(request) -class SSETransport(MCPTransportBase): +class SseTransport(McpTransportBase): """ Server-Sent Events transport implementation for MCP communication. Handles HTTP-based SSE communication with message queuing. @@ -82,7 +82,6 @@ class SSETransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize or get existing SSE session.""" @@ -298,7 +297,7 @@ class SSETransport(MCPTransportBase): headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) - async def handle_communication(self, *args, **kwargs) -> StreamingResponse: + async def handle_communication(self, **kwargs) -> StreamingResponse: """Main communication handler for SSE transport.""" return await self.handle_sse_stream(kwargs.get("request")) diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py index 4a5fd73..a44fc27 100644 --- a/gateway/mcp/stdio.py +++ b/gateway/mcp/stdio.py @@ -15,7 +15,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpAttributes, McpSessionsManager, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase STATUS_EOF = "eof" STATUS_DATA = "data" @@ -23,7 +23,7 @@ STATUS_WAIT = "wait" mcp_sessions_manager = McpSessionsManager() -class StdioTransport(MCPTransportBase): +class StdioTransport(McpTransportBase): """ STDIO transport implementation for MCP communication. Handles subprocess-based communication with stdin/stdout/stderr. @@ -33,7 +33,7 @@ class StdioTransport(MCPTransportBase): super().__init__(session_store) self.mcp_process: subprocess.Popen = None - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize session for stdio transport.""" session_attributes: McpAttributes = kwargs.get("session_attributes") session_id = self.generate_session_id() @@ -53,7 +53,7 @@ class StdioTransport(MCPTransportBase): mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}") return self.mcp_process - async def handle_communication(self, *args, **kwargs) -> None: + async def handle_communication(self, **kwargs) -> None: """Handle stdio communication loop.""" session_id: str = kwargs.get("session_id") mcp_process: subprocess.Popen = kwargs.get("mcp_process") diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index 39a2b01..b066fb7 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -1,7 +1,7 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" import json -from typing import Any, Optional, Union +from typing import Any, Optional import httpx from httpx_sse import aconnect_sse @@ -18,7 +18,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase gateway = APIRouter() mcp_sessions_manager = McpSessionsManager() @@ -69,7 +69,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response: async def create_streamable_transport_and_handle_request( request: Request, method: str, session_store: McpSessionsManager -) -> Union[Response, StreamingResponse]: +) -> Response | StreamingResponse: """Integration function for streamable routes.""" streamable_transport = StreamableTransport(session_store) return await streamable_transport.handle_communication( @@ -77,7 +77,7 @@ async def create_streamable_transport_and_handle_request( ) -class StreamableTransport(MCPTransportBase): +class StreamableTransport(McpTransportBase): """ Streamable HTTP transport implementation for MCP communication. Handles HTTP POST/GET/DELETE requests with JSON and streaming responses. @@ -85,7 +85,6 @@ class StreamableTransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize streamable HTTP session.""" @@ -111,7 +110,7 @@ class StreamableTransport(MCPTransportBase): async def handle_post_request( self, request: Request, request_body: dict[str, Any] - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Handle POST request to streamable endpoint.""" session_attributes = McpAttributes.from_request_headers(request.headers) session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -222,9 +221,7 @@ class StreamableTransport(MCPTransportBase): print(f"[MCP DELETE] Request error: {str(e)}") raise HTTPException(status_code=500, detail="Request error") from e - async def handle_communication( - self, *args, **kwargs - ) -> Union[Response, StreamingResponse]: + async def handle_communication(self, **kwargs) -> Response | StreamingResponse: """Main communication handler for streamable transport.""" request = kwargs.get("request") method = kwargs.get("method", "POST") @@ -262,7 +259,7 @@ class StreamableTransport(MCPTransportBase): session_id: str, session_attributes: McpAttributes, is_initialization_request: bool, - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Forward request to MCP server and handle response.""" async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: try: diff --git a/gateway/mcp/task_utils.py b/gateway/mcp/task_utils.py deleted file mode 100644 index 2ff8b87..0000000 --- a/gateway/mcp/task_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Task utilities for running async functions""" - -import asyncio -import concurrent.futures - -from contextlib import redirect_stdout -from typing import Any - -from gateway.mcp.log import MCP_LOG_FILE - - -def run_task_sync(async_func, *args, **kwargs) -> Any: - """ - Runs an asynchronous function synchronously in a separate - thread with its own event loop. This function blocks the calling - thread until completion or timeout (10 seconds). - - Args: - async_func: The async function to run - *args: Positional arguments to pass to the async function - **kwargs: Keyword arguments to pass to the async function - - Returns: - Any: The return value of the async function - """ - - def run_in_new_loop(): - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete( - async_func( - *args, - **kwargs, - ) - ) - finally: - loop.close() - - with redirect_stdout(MCP_LOG_FILE): - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_new_loop) - return future.result(timeout=10.0)