From 7efd15e2a9a4b33ddf13a2d75f46a039751c0c2b Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 3 Jun 2025 13:59:34 +0200 Subject: [PATCH] Move MCP related routes to the MCP directory and introduce the MCPTransportBase class. --- gateway/__main__.py | 4 +- gateway/mcp/mcp.py | 317 ------------------ .../{common => mcp}/mcp_sessions_manager.py | 0 gateway/mcp/mcp_transport_base.py | 116 +++++++ gateway/{routes/mcp_sse.py => mcp/sse.py} | 4 +- gateway/mcp/stdio.py | 314 +++++++++++++++++ .../mcp_streamable.py => mcp/streamable.py} | 4 +- gateway/{common/mcp_utils.py => mcp/utils.py} | 2 +- gateway/serve.py | 4 +- 9 files changed, 439 insertions(+), 326 deletions(-) delete mode 100644 gateway/mcp/mcp.py rename gateway/{common => mcp}/mcp_sessions_manager.py (100%) create mode 100644 gateway/mcp/mcp_transport_base.py rename gateway/{routes/mcp_sse.py => mcp/sse.py} (99%) create mode 100644 gateway/mcp/stdio.py rename gateway/{routes/mcp_streamable.py => mcp/streamable.py} (99%) rename gateway/{common/mcp_utils.py => mcp/utils.py} (99%) diff --git a/gateway/__main__.py b/gateway/__main__.py index 1aad784..de26a66 100644 --- a/gateway/__main__.py +++ b/gateway/__main__.py @@ -9,7 +9,7 @@ import time from typing import Optional -from gateway.mcp import mcp +from gateway.mcp import stdio as mcp_stdio from gateway.mcp.log import mcp_log @@ -235,7 +235,7 @@ def main(): sys.exit(1) if verb == "mcp": - return asyncio.run(mcp.execute(sys.argv[2:])) + return asyncio.run(mcp_stdio.execute(sys.argv[2:])) if verb == "server": if len(sys.argv) < 3: diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py deleted file mode 100644 index 4f262c8..0000000 --- a/gateway/mcp/mcp.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Gateway for MCP (Model Context Protocol) integration with Invariant.""" - -import asyncio -import json -import os -import platform -import select -import subprocess -import sys - -from gateway.common.constants import ( - MCP_METHOD, - MCP_TOOL_CALL, - MCP_LIST_TOOLS, - UTF_8, -) -from gateway.common.mcp_sessions_manager import ( - McpAttributes, - McpSessionsManager, -) -from gateway.common.mcp_utils import ( - generate_session_id, - hook_tool_call, - intercept_response, - update_mcp_server_in_session_metadata, - update_session_from_request, -) -from gateway.mcp.log import mcp_log, MCP_LOG_FILE - -STATUS_EOF = "eof" -STATUS_DATA = "data" -STATUS_WAIT = "wait" -session_store = McpSessionsManager() - - -def write_as_utf8_bytes(data: dict) -> bytes: - """Serializes dict to bytes using UTF-8 encoding.""" - return json.dumps(data).encode(UTF_8) + b"\n" - - -async def stream_and_forward_stdout( - session_id: str, mcp_process: subprocess.Popen -) -> None: - """Read from the mcp_process stdout, apply guardrails and forward to sys.stdout""" - loop = asyncio.get_event_loop() - while True: - if mcp_process.poll() is not None: - mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}") - break - - line = await loop.run_in_executor(None, mcp_process.stdout.readline) - if not line: - break - - try: - # Process complete JSON lines - decoded_line = line.decode(UTF_8).strip() - if not decoded_line: - continue - session = session_store.get_session(session_id) - if session.attributes.verbose: - mcp_log(f"[INFO] server -> client: {decoded_line}") - response_body = json.loads(decoded_line) - update_mcp_server_in_session_metadata(session, response_body) - - intercept_response_result, _ = await intercept_response( - session_id, session_store, response_body - ) - # Write and flush immediately - sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result)) - sys.stdout.buffer.flush() - except Exception as e: # pylint: disable=broad-except - mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}") - if line: - mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...") - - -async def stream_and_forward_stderr( - mcp_process: subprocess.Popen, read_chunk_size: int = 10 -) -> None: - """Read from the mcp_process stderr and write to sys.stderr""" - loop = asyncio.get_event_loop() - - while True: - # Read chunks asynchronously - chunk = await loop.run_in_executor( - None, lambda: mcp_process.stderr.read(read_chunk_size) - ) - - MCP_LOG_FILE.buffer.write(chunk) - MCP_LOG_FILE.buffer.flush() - - -async def _intercept_request( - session_id: str, mcp_process: subprocess.Popen, line: bytes -) -> None: - """ - Process a line of input from stdin, decode it and check for guardrails. - If the request is blocked, it returns a message indicating the block reason - otherwise it forwards the request to mcp_process stdin. - """ - session = session_store.get_session(session_id) - if session.attributes.verbose: - mcp_log(f"[INFO] client -> server: {line}") - - # Try to decode and parse as JSON to check for tool calls - try: - text = line.decode(UTF_8) - request_body = json.loads(text) - except json.JSONDecodeError as je: - mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}") - mcp_log(f"[ERROR] Problematic line: {line[:200]}...") - return - update_session_from_request(session, request_body) - # Refresh guardrails - await session.load_guardrails() - - hook_tool_call_result = {} - is_blocked = False - if request_body.get(MCP_METHOD) == MCP_TOOL_CALL: - hook_tool_call_result, is_blocked = await hook_tool_call( - session_id, session_store, request_body - ) - elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS: - hook_tool_call_result, is_blocked = await hook_tool_call( - session_id=session_id, - session_store=session_store, - request_body={ - "id": request_body.get("id"), - "method": MCP_LIST_TOOLS, - "params": {"name": MCP_LIST_TOOLS, "arguments": {}}, - }, - ) - if is_blocked: - sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result)) - sys.stdout.buffer.flush() - return - mcp_process.stdin.write(write_as_utf8_bytes(request_body)) - mcp_process.stdin.flush() - - -async def wait_for_stdin_input( - loop: asyncio.AbstractEventLoop, stdin_fd: int -) -> tuple[bytes | None, str]: - """ - Platform-specific implementation to wait for and read input from stdin. - - Args: - loop: The asyncio event loop - stdin_fd: The file descriptor for stdin - - Returns: - tuple[bytes | None, str]: A tuple containing: - - The data read from stdin or None - - Status: 'eof' if EOF detected, 'data' if data available, 'wait' if no data yet - """ - if platform.system() == "Windows": - # On Windows, we can't use select for stdin - # Instead, we'll use a brief sleep and then try to read - await asyncio.sleep(0.01) - try: - chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) - if not chunk: # Empty bytes means EOF - return None, STATUS_EOF - return chunk, STATUS_DATA - except (BlockingIOError, OSError): - # No data available yet - return None, STATUS_WAIT - else: - # On Unix-like systems, use select - ready, _, _ = await loop.run_in_executor( - None, lambda: select.select([stdin_fd], [], [], 0.1) - ) - - if not ready: - # No input available, yield to other tasks - await asyncio.sleep(0.01) - return None, STATUS_WAIT - - # Read available data - chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) - if not chunk: # Empty bytes means EOF - return None, STATUS_EOF - return chunk, STATUS_DATA - - -async def run_stdio_input_loop( - session_id: str, - mcp_process: subprocess.Popen, - stdout_task: asyncio.Task, - stderr_task: asyncio.Task, -) -> None: - """Handle standard input, intercept call and forward requests to mcp_process stdin.""" - loop = asyncio.get_event_loop() - stdin_fd = sys.stdin.fileno() - buffer = b"" - - # Set stdin to non-blocking mode - os.set_blocking(stdin_fd, False) - - try: - while True: - # Get input using platform-specific method - chunk, status = await wait_for_stdin_input(loop, stdin_fd) - - if status == STATUS_EOF: - # EOF detected, break the loop - break - elif status == STATUS_WAIT: - # No data available yet, continue polling - continue - elif status == STATUS_DATA: - # We got some data, process it - buffer += chunk - - # Process complete lines - while b"\n" in buffer: - line, buffer = buffer.split(b"\n", 1) - if not line: - continue - - await _intercept_request(session_id, mcp_process, line) - except (BrokenPipeError, KeyboardInterrupt): - # Broken pipe = client disappeared, just start shutdown - mcp_log("Client disconnected or keyboard interrupt") - finally: - # Close stdin - if mcp_process.stdin: - mcp_process.stdin.close() - - # Process any remaining data - while b"\n" in buffer: - line, buffer = buffer.split(b"\n", 1) - if line: - await _intercept_request(session_id, mcp_process, line) - - # Terminate process if needed - if mcp_process.poll() is None: - mcp_process.terminate() - try: - await asyncio.wait_for( - loop.run_in_executor(None, mcp_process.wait), timeout=2 - ) - except asyncio.TimeoutError: - mcp_process.kill() - - # Cancel I/O tasks - stdout_task.cancel() - stderr_task.cancel() - - # Final flush - sys.stdout.flush() - - -def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: - """ - Splits CLI arguments into two parts: - 1. Arguments intended for the MCP gateway (everything before `--exec`) - 2. Arguments for the underlying MCP server (everything after `--exec`) - - Parameters: - args (list[str]): The list of CLI arguments. - - Returns: - Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args) - """ - if not args: - mcp_log("[ERROR] No arguments provided.") - sys.exit(1) - - try: - exec_index = args.index("--exec") - except ValueError: - mcp_log("[ERROR] '--exec' flag not found in arguments.") - sys.exit(1) - - mcp_gateway_args = args[:exec_index] - mcp_server_command_args = args[exec_index + 1 :] - - if not mcp_server_command_args: - mcp_log("[ERROR] No arguments provided after '--exec'.") - sys.exit(1) - - return mcp_gateway_args, mcp_server_command_args - - -async def execute(args: list[str] = None): - """Main function to execute the MCP gateway.""" - if "INVARIANT_API_KEY" not in os.environ: - mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.") - sys.exit(1) - - mcp_log("[INFO] Running with Python version:", sys.version) - - mcp_gateway_args, mcp_server_command_args = split_args(args) - session_id = generate_session_id() - await session_store.initialize_session( - session_id, - McpAttributes.from_cli_args(mcp_gateway_args), - ) - - mcp_process = subprocess.Popen( - mcp_server_command_args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - bufsize=0, - ) - - # Start async tasks for stdout and stderr - stdout_task = asyncio.create_task( - stream_and_forward_stdout(session_id, mcp_process) - ) - stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process)) - - # Handle forwarding stdin and intercept tool calls - await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task) diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py similarity index 100% rename from gateway/common/mcp_sessions_manager.py rename to gateway/mcp/mcp_sessions_manager.py diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py new file mode 100644 index 0000000..3cc8d81 --- /dev/null +++ b/gateway/mcp/mcp_transport_base.py @@ -0,0 +1,116 @@ +""" +MCP Transport Strategy Pattern Implementation + +This module defines an abstract base class for MCP transports. +""" + +from abc import ABC, abstractmethod +from typing import Any, Tuple + +from gateway.common.constants import ( + MCP_METHOD, + MCP_TOOL_CALL, + MCP_LIST_TOOLS, +) +from gateway.mcp.mcp_sessions_manager import McpSessionsManager +from gateway.mcp.utils import ( + hook_tool_call, + intercept_response, + update_mcp_server_in_session_metadata, + update_session_from_request, +) + + +class MCPTransportBase(ABC): + """ + Abstract base class for MCP transport strategies. + + This class defines the common interface and shared functionality for all MCP transports, + using the Template Method pattern for request/response processing. + """ + + def __init__(self, session_store: McpSessionsManager): + self.session_store = session_store + + async def process_outgoing_request( + self, session_id: str, request_data: dict[str, Any] + ) -> Tuple[dict[str, Any], bool]: + """ + Template method for processing outgoing requests to MCP server. + + Returns: + Tuple[processed_request_data, is_blocked] + """ + # Update session with request information + session = self.session_store.get_session(session_id) + update_session_from_request(session, request_data) + + # Refresh guardrails + await session.load_guardrails() + + # Check if request should be intercepted for guardrails + if self._should_intercept_request(request_data): + return await self._intercept_outgoing_request(session_id, request_data) + + return request_data, False + + async def process_incoming_response( + self, session_id: str, response_data: dict[str, Any] + ) -> Tuple[dict[str, Any], bool]: + """ + Template method for processing incoming responses from MCP server. + + Returns: + Tuple[processed_response, is_blocked] + """ + # Update session with server information + session = self.session_store.get_session(session_id) + update_mcp_server_in_session_metadata(session, response_data) + + # Intercept and apply guardrails to response + return await intercept_response(session_id, self.session_store, response_data) + + def _should_intercept_request(self, request_data: dict[str, Any]) -> bool: + """Check if request should be intercepted for guardrails.""" + method = request_data.get(MCP_METHOD) + return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS] + + async def _intercept_outgoing_request( + self, session_id: str, request_data: dict[str, Any] + ) -> Tuple[dict[str, Any], bool]: + """Common request interception logic for guardrails.""" + method = request_data.get(MCP_METHOD) + + interception_result = request_data + is_blocked = False + if method == MCP_TOOL_CALL: + interception_result, is_blocked = await hook_tool_call( + session_id, self.session_store, request_data + ) + elif method == MCP_LIST_TOOLS: + interception_result, is_blocked = await hook_tool_call( + session_id=session_id, + session_store=self.session_store, + request_body={ + "id": request_data.get("id"), + "method": MCP_LIST_TOOLS, + "params": {"name": MCP_LIST_TOOLS, "arguments": {}}, + }, + ) + + return interception_result, is_blocked + + def _is_initialization_request(self, request_data: dict[str, Any]) -> bool: + """Check if request is an initialization request.""" + return ( + request_data.get("method") in ["initialize", "notifications/initialized"] + and "jsonrpc" in request_data + ) + + @abstractmethod + async def initialize_session(self, *args, **kwargs) -> str: + """Initialize a session for this transport type.""" + + @abstractmethod + async def handle_communication(self, *args, **kwargs) -> Any: + """Handle the main communication for this transport.""" diff --git a/gateway/routes/mcp_sse.py b/gateway/mcp/sse.py similarity index 99% rename from gateway/routes/mcp_sse.py rename to gateway/mcp/sse.py index 3c27ebf..aee7b44 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/mcp/sse.py @@ -17,11 +17,11 @@ from gateway.common.constants import ( MCP_LIST_TOOLS, UTF_8, ) -from gateway.common.mcp_sessions_manager import ( +from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.common.mcp_utils import ( +from gateway.mcp.utils import ( get_mcp_server_base_url, hook_tool_call, intercept_response, diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py new file mode 100644 index 0000000..d446048 --- /dev/null +++ b/gateway/mcp/stdio.py @@ -0,0 +1,314 @@ +"""Gateway for MCP (Model Context Protocol) integration with Invariant.""" + +import asyncio +import json +import os +import platform +import select +import subprocess +import sys +from typing import Optional, Tuple + +from gateway.common.constants import ( + UTF_8, +) +from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_sessions_manager import ( + McpAttributes, + McpSessionsManager, +) +from gateway.mcp.utils import ( + generate_session_id, +) +from gateway.mcp.log import mcp_log, MCP_LOG_FILE + +STATUS_EOF = "eof" +STATUS_DATA = "data" +STATUS_WAIT = "wait" +mcp_sessions_manager = McpSessionsManager() + + +class StdioTransport(MCPTransportBase): + """ + STDIO transport implementation for MCP communication. + Handles subprocess-based communication with stdin/stdout/stderr. + """ + + def __init__(self, session_store: McpSessionsManager): + super().__init__(session_store) + self.mcp_process: subprocess.Popen = None + + async def initialize_session(self, *args, **kwargs) -> str: + """Initialize session for stdio transport.""" + session_attributes: McpAttributes = kwargs.get("session_attributes") + session_id = generate_session_id() + await self.session_store.initialize_session(session_id, session_attributes) + mcp_log(f"Created stdio session with ID: {session_id}") + return session_id + + def start_mcp_process(self, mcp_server_command_args: list) -> subprocess.Popen: + """Start the MCP server subprocess.""" + self.mcp_process = subprocess.Popen( + mcp_server_command_args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=0, + ) + mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}") + return self.mcp_process + + async def handle_communication(self, *args, **kwargs) -> None: + """Handle stdio communication loop.""" + session_id: str = kwargs.get("session_id") + mcp_process: subprocess.Popen = kwargs.get("mcp_process") + if not session_id or not mcp_process: + raise ValueError( + "session_id and mcp_process are required for stdio transport" + ) + + self.mcp_process = mcp_process + + # Start async tasks for stdout and stderr + stdout_task = asyncio.create_task(self._stream_and_forward_stdout(session_id)) + stderr_task = asyncio.create_task(self._stream_and_forward_stderr()) + + try: + # Handle stdin input loop + await self._run_stdio_input_loop(session_id) + finally: + # Cleanup + if self.mcp_process and self.mcp_process.stdin: + self.mcp_process.stdin.close() + + # Terminate process if needed + if self.mcp_process and self.mcp_process.poll() is None: + self.mcp_process.terminate() + try: + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor( + None, self.mcp_process.wait + ), + timeout=2, + ) + except asyncio.TimeoutError: + self.mcp_process.kill() + + # Cancel I/O tasks + stdout_task.cancel() + stderr_task.cancel() + + # Final flush + sys.stdout.flush() + + async def _stream_and_forward_stdout(self, session_id: str) -> None: + """Read from MCP process stdout, apply guardrails and forward to sys.stdout.""" + loop = asyncio.get_event_loop() + + while True: + if self.mcp_process.poll() is not None: + mcp_log( + f"[ERROR] MCP process terminated with code: {self.mcp_process.poll()}" + ) + break + + line = await loop.run_in_executor(None, self.mcp_process.stdout.readline) + if not line: + break + + try: + decoded_line = line.decode(UTF_8).strip() + if not decoded_line: + continue + + session = self.session_store.get_session(session_id) + if session.attributes.verbose: + mcp_log(f"[INFO] server -> client: {decoded_line}") + + response_body = json.loads(decoded_line) + processed_response, _ = await self.process_incoming_response( + session_id, response_body + ) + + sys.stdout.buffer.write(self._serialize_to_bytes(processed_response)) + sys.stdout.buffer.flush() + except Exception as e: # pylint: disable=broad-except + mcp_log(f"[ERROR] Error in _stream_and_forward_stdout: {str(e)}") + if line: + mcp_log(f"[ERROR] Problematic line: {line[:200]}...") + + async def _stream_and_forward_stderr(self) -> None: + """Read from MCP process stderr and write to log file.""" + loop = asyncio.get_event_loop() + + while True: + chunk = await loop.run_in_executor( + None, lambda: self.mcp_process.stderr.read(10) + ) + if not chunk: + break + MCP_LOG_FILE.buffer.write(chunk) + MCP_LOG_FILE.buffer.flush() + + async def _run_stdio_input_loop(self, session_id: str) -> None: + """Handle standard input, intercept calls and forward requests to MCP process stdin.""" + loop = asyncio.get_event_loop() + stdin_fd = sys.stdin.fileno() + buffer = b"" + + # Set stdin to non-blocking mode + os.set_blocking(stdin_fd, False) + + try: + while True: + # Get input using platform-specific method + chunk, status = await self._wait_for_stdin_input(loop, stdin_fd) + + if status == STATUS_EOF: + break + elif status == STATUS_WAIT: + continue + elif status == STATUS_DATA: + buffer += chunk + + # Process complete lines + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if not line: + continue + + await self._process_stdin_line(session_id, line) + + except (BrokenPipeError, KeyboardInterrupt): + mcp_log("Client disconnected or keyboard interrupt") + finally: + # Process any remaining data + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if line: + await self._process_stdin_line(session_id, line) + + async def _process_stdin_line(self, session_id: str, line: bytes) -> None: + """Process a line of input from stdin.""" + session = self.session_store.get_session(session_id) + if session.attributes.verbose: + mcp_log(f"[INFO] client -> server: {line}") + + try: + text = line.decode(UTF_8) + request_body = json.loads(text) + except json.JSONDecodeError as je: + mcp_log(f"[ERROR] JSON decode error: {str(je)}") + mcp_log(f"[ERROR] Problematic line: {line[:200]}...") + return + + processed_request, is_blocked = await self.process_outgoing_request( + session_id, request_body + ) + + if is_blocked: + sys.stdout.buffer.write(self._serialize_to_bytes(processed_request)) + sys.stdout.buffer.flush() + return + self.mcp_process.stdin.write(self._serialize_to_bytes(request_body)) + self.mcp_process.stdin.flush() + + async def _wait_for_stdin_input( + self, loop: asyncio.AbstractEventLoop, stdin_fd: int + ) -> Tuple[Optional[bytes], str]: + """Platform-specific implementation to wait for and read input from stdin.""" + if platform.system() == "Windows": + await asyncio.sleep(0.01) + try: + chunk = await loop.run_in_executor( + None, lambda: os.read(stdin_fd, 4096) + ) + if not chunk: + return None, STATUS_EOF + return chunk, STATUS_DATA + except (BlockingIOError, OSError): + return None, STATUS_WAIT + else: + # Unix-like systems + ready, _, _ = await loop.run_in_executor( + None, lambda: select.select([stdin_fd], [], [], 0.1) + ) + + if not ready: + await asyncio.sleep(0.01) + return None, STATUS_WAIT + + chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) + if not chunk: + return None, STATUS_EOF + return chunk, STATUS_DATA + + def _serialize_to_bytes(self, data: dict) -> bytes: + """Serialize dict to bytes using UTF-8 encoding.""" + return json.dumps(data).encode(UTF_8) + b"\n" + + +async def create_stdio_transport_and_execute( + session_store: McpSessionsManager, + session_attributes: McpAttributes, + mcp_server_command_args: list, +) -> None: + """Integration function for stdio execution.""" + stdio_transport = StdioTransport(session_store=session_store) + + session_id = await stdio_transport.initialize_session( + session_attributes=session_attributes + ) + + await stdio_transport.handle_communication( + session_id=session_id, + mcp_process=stdio_transport.start_mcp_process(mcp_server_command_args), + ) + + +def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: + """ + Splits CLI arguments into two parts: + 1. Arguments intended for the MCP gateway (everything before `--exec`) + 2. Arguments for the underlying MCP server (everything after `--exec`) + """ + if not args: + mcp_log("[ERROR] No arguments provided.") + sys.exit(1) + + try: + exec_index = args.index("--exec") + except ValueError: + mcp_log("[ERROR] '--exec' flag not found in arguments.") + sys.exit(1) + + mcp_gateway_args = args[:exec_index] + mcp_server_command_args = args[exec_index + 1 :] + + if not mcp_server_command_args: + mcp_log("[ERROR] No arguments provided after '--exec'.") + sys.exit(1) + + return mcp_gateway_args, mcp_server_command_args + + +async def execute(args: list[str] = None): + """Main function to execute the MCP gateway using transport strategy.""" + if "INVARIANT_API_KEY" not in os.environ: + mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.") + sys.exit(1) + + mcp_log("[INFO] Running with Python version:", sys.version) + + # Parse arguments + mcp_gateway_args, mcp_server_command_args = split_args(args) + + # Create session store and attributes + session_attributes = McpAttributes.from_cli_args(mcp_gateway_args) + + # Use stdio transport strategy + await create_stdio_transport_and_execute( + session_store=mcp_sessions_manager, + session_attributes=session_attributes, + mcp_server_command_args=mcp_server_command_args, + ) diff --git a/gateway/routes/mcp_streamable.py b/gateway/mcp/streamable.py similarity index 99% rename from gateway/routes/mcp_streamable.py rename to gateway/mcp/streamable.py index 4d806d9..0716565 100644 --- a/gateway/routes/mcp_streamable.py +++ b/gateway/mcp/streamable.py @@ -15,11 +15,11 @@ from gateway.common.constants import ( MCP_TOOL_CALL, UTF_8, ) -from gateway.common.mcp_sessions_manager import ( +from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.common.mcp_utils import ( +from gateway.mcp.utils import ( generate_session_id, get_mcp_server_base_url, hook_tool_call, diff --git a/gateway/common/mcp_utils.py b/gateway/mcp/utils.py similarity index 99% rename from gateway/common/mcp_utils.py rename to gateway/mcp/utils.py index 208ad19..2903705 100644 --- a/gateway/common/mcp_utils.py +++ b/gateway/mcp/utils.py @@ -22,7 +22,7 @@ from gateway.common.constants import ( MCP_TOOL_CALL, ) from gateway.common.guardrails import GuardrailAction -from gateway.common.mcp_sessions_manager import ( +from gateway.mcp.mcp_sessions_manager import ( McpSession, McpSessionsManager, ) diff --git a/gateway/serve.py b/gateway/serve.py index ebf8f58..8e752c1 100644 --- a/gateway/serve.py +++ b/gateway/serve.py @@ -7,8 +7,8 @@ from starlette_compress import CompressMiddleware from gateway.routes.anthropic import gateway as anthropic_gateway from gateway.routes.gemini import gateway as gemini_gateway from gateway.routes.open_ai import gateway as open_ai_gateway -from gateway.routes.mcp_sse import gateway as mcp_sse_gateway -from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway +from gateway.mcp.sse import gateway as mcp_sse_gateway +from gateway.mcp.streamable import gateway as mcp_streamable_gateway app = fastapi.app = fastapi.FastAPI( docs_url="/api/v1/gateway/docs",