Move MCP related routes to the MCP directory and introduce the MCPTransportBase class.

This commit is contained in:
Hemang
2025-06-03 13:59:34 +02:00
committed by Hemang Sarkar
parent e8106776b4
commit 7efd15e2a9
9 changed files with 439 additions and 326 deletions
+2 -2
View File
@@ -9,7 +9,7 @@ import time
from typing import Optional
from gateway.mcp import mcp
from gateway.mcp import stdio as mcp_stdio
from gateway.mcp.log import mcp_log
@@ -235,7 +235,7 @@ def main():
sys.exit(1)
if verb == "mcp":
return asyncio.run(mcp.execute(sys.argv[2:]))
return asyncio.run(mcp_stdio.execute(sys.argv[2:]))
if verb == "server":
if len(sys.argv) < 3:
-317
View File
@@ -1,317 +0,0 @@
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
import asyncio
import json
import os
import platform
import select
import subprocess
import sys
from gateway.common.constants import (
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
UTF_8,
)
from gateway.common.mcp_sessions_manager import (
McpAttributes,
McpSessionsManager,
)
from gateway.common.mcp_utils import (
generate_session_id,
hook_tool_call,
intercept_response,
update_mcp_server_in_session_metadata,
update_session_from_request,
)
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
STATUS_EOF = "eof"
STATUS_DATA = "data"
STATUS_WAIT = "wait"
session_store = McpSessionsManager()
def write_as_utf8_bytes(data: dict) -> bytes:
"""Serializes dict to bytes using UTF-8 encoding."""
return json.dumps(data).encode(UTF_8) + b"\n"
async def stream_and_forward_stdout(
session_id: str, mcp_process: subprocess.Popen
) -> None:
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
loop = asyncio.get_event_loop()
while True:
if mcp_process.poll() is not None:
mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}")
break
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
if not line:
break
try:
# Process complete JSON lines
decoded_line = line.decode(UTF_8).strip()
if not decoded_line:
continue
session = session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] server -> client: {decoded_line}")
response_body = json.loads(decoded_line)
update_mcp_server_in_session_metadata(session, response_body)
intercept_response_result, _ = await intercept_response(
session_id, session_store, response_body
)
# Write and flush immediately
sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result))
sys.stdout.buffer.flush()
except Exception as e: # pylint: disable=broad-except
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
if line:
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
async def stream_and_forward_stderr(
mcp_process: subprocess.Popen, read_chunk_size: int = 10
) -> None:
"""Read from the mcp_process stderr and write to sys.stderr"""
loop = asyncio.get_event_loop()
while True:
# Read chunks asynchronously
chunk = await loop.run_in_executor(
None, lambda: mcp_process.stderr.read(read_chunk_size)
)
MCP_LOG_FILE.buffer.write(chunk)
MCP_LOG_FILE.buffer.flush()
async def _intercept_request(
session_id: str, mcp_process: subprocess.Popen, line: bytes
) -> None:
"""
Process a line of input from stdin, decode it and check for guardrails.
If the request is blocked, it returns a message indicating the block reason
otherwise it forwards the request to mcp_process stdin.
"""
session = session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] client -> server: {line}")
# Try to decode and parse as JSON to check for tool calls
try:
text = line.decode(UTF_8)
request_body = json.loads(text)
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
return
update_session_from_request(session, request_body)
# Refresh guardrails
await session.load_guardrails()
hook_tool_call_result = {}
is_blocked = False
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id, session_store, request_body
)
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=session_store,
request_body={
"id": request_body.get("id"),
"method": MCP_LIST_TOOLS,
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
},
)
if is_blocked:
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
sys.stdout.buffer.flush()
return
mcp_process.stdin.write(write_as_utf8_bytes(request_body))
mcp_process.stdin.flush()
async def wait_for_stdin_input(
loop: asyncio.AbstractEventLoop, stdin_fd: int
) -> tuple[bytes | None, str]:
"""
Platform-specific implementation to wait for and read input from stdin.
Args:
loop: The asyncio event loop
stdin_fd: The file descriptor for stdin
Returns:
tuple[bytes | None, str]: A tuple containing:
- The data read from stdin or None
- Status: 'eof' if EOF detected, 'data' if data available, 'wait' if no data yet
"""
if platform.system() == "Windows":
# On Windows, we can't use select for stdin
# Instead, we'll use a brief sleep and then try to read
await asyncio.sleep(0.01)
try:
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk: # Empty bytes means EOF
return None, STATUS_EOF
return chunk, STATUS_DATA
except (BlockingIOError, OSError):
# No data available yet
return None, STATUS_WAIT
else:
# On Unix-like systems, use select
ready, _, _ = await loop.run_in_executor(
None, lambda: select.select([stdin_fd], [], [], 0.1)
)
if not ready:
# No input available, yield to other tasks
await asyncio.sleep(0.01)
return None, STATUS_WAIT
# Read available data
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk: # Empty bytes means EOF
return None, STATUS_EOF
return chunk, STATUS_DATA
async def run_stdio_input_loop(
session_id: str,
mcp_process: subprocess.Popen,
stdout_task: asyncio.Task,
stderr_task: asyncio.Task,
) -> None:
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
loop = asyncio.get_event_loop()
stdin_fd = sys.stdin.fileno()
buffer = b""
# Set stdin to non-blocking mode
os.set_blocking(stdin_fd, False)
try:
while True:
# Get input using platform-specific method
chunk, status = await wait_for_stdin_input(loop, stdin_fd)
if status == STATUS_EOF:
# EOF detected, break the loop
break
elif status == STATUS_WAIT:
# No data available yet, continue polling
continue
elif status == STATUS_DATA:
# We got some data, process it
buffer += chunk
# Process complete lines
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
await _intercept_request(session_id, mcp_process, line)
except (BrokenPipeError, KeyboardInterrupt):
# Broken pipe = client disappeared, just start shutdown
mcp_log("Client disconnected or keyboard interrupt")
finally:
# Close stdin
if mcp_process.stdin:
mcp_process.stdin.close()
# Process any remaining data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if line:
await _intercept_request(session_id, mcp_process, line)
# Terminate process if needed
if mcp_process.poll() is None:
mcp_process.terminate()
try:
await asyncio.wait_for(
loop.run_in_executor(None, mcp_process.wait), timeout=2
)
except asyncio.TimeoutError:
mcp_process.kill()
# Cancel I/O tasks
stdout_task.cancel()
stderr_task.cancel()
# Final flush
sys.stdout.flush()
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
"""
Splits CLI arguments into two parts:
1. Arguments intended for the MCP gateway (everything before `--exec`)
2. Arguments for the underlying MCP server (everything after `--exec`)
Parameters:
args (list[str]): The list of CLI arguments.
Returns:
Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args)
"""
if not args:
mcp_log("[ERROR] No arguments provided.")
sys.exit(1)
try:
exec_index = args.index("--exec")
except ValueError:
mcp_log("[ERROR] '--exec' flag not found in arguments.")
sys.exit(1)
mcp_gateway_args = args[:exec_index]
mcp_server_command_args = args[exec_index + 1 :]
if not mcp_server_command_args:
mcp_log("[ERROR] No arguments provided after '--exec'.")
sys.exit(1)
return mcp_gateway_args, mcp_server_command_args
async def execute(args: list[str] = None):
"""Main function to execute the MCP gateway."""
if "INVARIANT_API_KEY" not in os.environ:
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
sys.exit(1)
mcp_log("[INFO] Running with Python version:", sys.version)
mcp_gateway_args, mcp_server_command_args = split_args(args)
session_id = generate_session_id()
await session_store.initialize_session(
session_id,
McpAttributes.from_cli_args(mcp_gateway_args),
)
mcp_process = subprocess.Popen(
mcp_server_command_args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0,
)
# Start async tasks for stdout and stderr
stdout_task = asyncio.create_task(
stream_and_forward_stdout(session_id, mcp_process)
)
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
# Handle forwarding stdin and intercept tool calls
await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task)
+116
View File
@@ -0,0 +1,116 @@
"""
MCP Transport Strategy Pattern Implementation
This module defines an abstract base class for MCP transports.
"""
from abc import ABC, abstractmethod
from typing import Any, Tuple
from gateway.common.constants import (
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
)
from gateway.mcp.mcp_sessions_manager import McpSessionsManager
from gateway.mcp.utils import (
hook_tool_call,
intercept_response,
update_mcp_server_in_session_metadata,
update_session_from_request,
)
class MCPTransportBase(ABC):
"""
Abstract base class for MCP transport strategies.
This class defines the common interface and shared functionality for all MCP transports,
using the Template Method pattern for request/response processing.
"""
def __init__(self, session_store: McpSessionsManager):
self.session_store = session_store
async def process_outgoing_request(
self, session_id: str, request_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
"""
Template method for processing outgoing requests to MCP server.
Returns:
Tuple[processed_request_data, is_blocked]
"""
# Update session with request information
session = self.session_store.get_session(session_id)
update_session_from_request(session, request_data)
# Refresh guardrails
await session.load_guardrails()
# Check if request should be intercepted for guardrails
if self._should_intercept_request(request_data):
return await self._intercept_outgoing_request(session_id, request_data)
return request_data, False
async def process_incoming_response(
self, session_id: str, response_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
"""
Template method for processing incoming responses from MCP server.
Returns:
Tuple[processed_response, is_blocked]
"""
# Update session with server information
session = self.session_store.get_session(session_id)
update_mcp_server_in_session_metadata(session, response_data)
# Intercept and apply guardrails to response
return await intercept_response(session_id, self.session_store, response_data)
def _should_intercept_request(self, request_data: dict[str, Any]) -> bool:
"""Check if request should be intercepted for guardrails."""
method = request_data.get(MCP_METHOD)
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
async def _intercept_outgoing_request(
self, session_id: str, request_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
"""Common request interception logic for guardrails."""
method = request_data.get(MCP_METHOD)
interception_result = request_data
is_blocked = False
if method == MCP_TOOL_CALL:
interception_result, is_blocked = await hook_tool_call(
session_id, self.session_store, request_data
)
elif method == MCP_LIST_TOOLS:
interception_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=self.session_store,
request_body={
"id": request_data.get("id"),
"method": MCP_LIST_TOOLS,
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
},
)
return interception_result, is_blocked
def _is_initialization_request(self, request_data: dict[str, Any]) -> bool:
"""Check if request is an initialization request."""
return (
request_data.get("method") in ["initialize", "notifications/initialized"]
and "jsonrpc" in request_data
)
@abstractmethod
async def initialize_session(self, *args, **kwargs) -> str:
"""Initialize a session for this transport type."""
@abstractmethod
async def handle_communication(self, *args, **kwargs) -> Any:
"""Handle the main communication for this transport."""
@@ -17,11 +17,11 @@ from gateway.common.constants import (
MCP_LIST_TOOLS,
UTF_8,
)
from gateway.common.mcp_sessions_manager import (
from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.common.mcp_utils import (
from gateway.mcp.utils import (
get_mcp_server_base_url,
hook_tool_call,
intercept_response,
+314
View File
@@ -0,0 +1,314 @@
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
import asyncio
import json
import os
import platform
import select
import subprocess
import sys
from typing import Optional, Tuple
from gateway.common.constants import (
UTF_8,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_sessions_manager import (
McpAttributes,
McpSessionsManager,
)
from gateway.mcp.utils import (
generate_session_id,
)
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
STATUS_EOF = "eof"
STATUS_DATA = "data"
STATUS_WAIT = "wait"
mcp_sessions_manager = McpSessionsManager()
class StdioTransport(MCPTransportBase):
"""
STDIO transport implementation for MCP communication.
Handles subprocess-based communication with stdin/stdout/stderr.
"""
def __init__(self, session_store: McpSessionsManager):
super().__init__(session_store)
self.mcp_process: subprocess.Popen = None
async def initialize_session(self, *args, **kwargs) -> str:
"""Initialize session for stdio transport."""
session_attributes: McpAttributes = kwargs.get("session_attributes")
session_id = generate_session_id()
await self.session_store.initialize_session(session_id, session_attributes)
mcp_log(f"Created stdio session with ID: {session_id}")
return session_id
def start_mcp_process(self, mcp_server_command_args: list) -> subprocess.Popen:
"""Start the MCP server subprocess."""
self.mcp_process = subprocess.Popen(
mcp_server_command_args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0,
)
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
return self.mcp_process
async def handle_communication(self, *args, **kwargs) -> None:
"""Handle stdio communication loop."""
session_id: str = kwargs.get("session_id")
mcp_process: subprocess.Popen = kwargs.get("mcp_process")
if not session_id or not mcp_process:
raise ValueError(
"session_id and mcp_process are required for stdio transport"
)
self.mcp_process = mcp_process
# Start async tasks for stdout and stderr
stdout_task = asyncio.create_task(self._stream_and_forward_stdout(session_id))
stderr_task = asyncio.create_task(self._stream_and_forward_stderr())
try:
# Handle stdin input loop
await self._run_stdio_input_loop(session_id)
finally:
# Cleanup
if self.mcp_process and self.mcp_process.stdin:
self.mcp_process.stdin.close()
# Terminate process if needed
if self.mcp_process and self.mcp_process.poll() is None:
self.mcp_process.terminate()
try:
await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(
None, self.mcp_process.wait
),
timeout=2,
)
except asyncio.TimeoutError:
self.mcp_process.kill()
# Cancel I/O tasks
stdout_task.cancel()
stderr_task.cancel()
# Final flush
sys.stdout.flush()
async def _stream_and_forward_stdout(self, session_id: str) -> None:
"""Read from MCP process stdout, apply guardrails and forward to sys.stdout."""
loop = asyncio.get_event_loop()
while True:
if self.mcp_process.poll() is not None:
mcp_log(
f"[ERROR] MCP process terminated with code: {self.mcp_process.poll()}"
)
break
line = await loop.run_in_executor(None, self.mcp_process.stdout.readline)
if not line:
break
try:
decoded_line = line.decode(UTF_8).strip()
if not decoded_line:
continue
session = self.session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] server -> client: {decoded_line}")
response_body = json.loads(decoded_line)
processed_response, _ = await self.process_incoming_response(
session_id, response_body
)
sys.stdout.buffer.write(self._serialize_to_bytes(processed_response))
sys.stdout.buffer.flush()
except Exception as e: # pylint: disable=broad-except
mcp_log(f"[ERROR] Error in _stream_and_forward_stdout: {str(e)}")
if line:
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
async def _stream_and_forward_stderr(self) -> None:
"""Read from MCP process stderr and write to log file."""
loop = asyncio.get_event_loop()
while True:
chunk = await loop.run_in_executor(
None, lambda: self.mcp_process.stderr.read(10)
)
if not chunk:
break
MCP_LOG_FILE.buffer.write(chunk)
MCP_LOG_FILE.buffer.flush()
async def _run_stdio_input_loop(self, session_id: str) -> None:
"""Handle standard input, intercept calls and forward requests to MCP process stdin."""
loop = asyncio.get_event_loop()
stdin_fd = sys.stdin.fileno()
buffer = b""
# Set stdin to non-blocking mode
os.set_blocking(stdin_fd, False)
try:
while True:
# Get input using platform-specific method
chunk, status = await self._wait_for_stdin_input(loop, stdin_fd)
if status == STATUS_EOF:
break
elif status == STATUS_WAIT:
continue
elif status == STATUS_DATA:
buffer += chunk
# Process complete lines
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
await self._process_stdin_line(session_id, line)
except (BrokenPipeError, KeyboardInterrupt):
mcp_log("Client disconnected or keyboard interrupt")
finally:
# Process any remaining data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if line:
await self._process_stdin_line(session_id, line)
async def _process_stdin_line(self, session_id: str, line: bytes) -> None:
"""Process a line of input from stdin."""
session = self.session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] client -> server: {line}")
try:
text = line.decode(UTF_8)
request_body = json.loads(text)
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error: {str(je)}")
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
return
processed_request, is_blocked = await self.process_outgoing_request(
session_id, request_body
)
if is_blocked:
sys.stdout.buffer.write(self._serialize_to_bytes(processed_request))
sys.stdout.buffer.flush()
return
self.mcp_process.stdin.write(self._serialize_to_bytes(request_body))
self.mcp_process.stdin.flush()
async def _wait_for_stdin_input(
self, loop: asyncio.AbstractEventLoop, stdin_fd: int
) -> Tuple[Optional[bytes], str]:
"""Platform-specific implementation to wait for and read input from stdin."""
if platform.system() == "Windows":
await asyncio.sleep(0.01)
try:
chunk = await loop.run_in_executor(
None, lambda: os.read(stdin_fd, 4096)
)
if not chunk:
return None, STATUS_EOF
return chunk, STATUS_DATA
except (BlockingIOError, OSError):
return None, STATUS_WAIT
else:
# Unix-like systems
ready, _, _ = await loop.run_in_executor(
None, lambda: select.select([stdin_fd], [], [], 0.1)
)
if not ready:
await asyncio.sleep(0.01)
return None, STATUS_WAIT
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk:
return None, STATUS_EOF
return chunk, STATUS_DATA
def _serialize_to_bytes(self, data: dict) -> bytes:
"""Serialize dict to bytes using UTF-8 encoding."""
return json.dumps(data).encode(UTF_8) + b"\n"
async def create_stdio_transport_and_execute(
session_store: McpSessionsManager,
session_attributes: McpAttributes,
mcp_server_command_args: list,
) -> None:
"""Integration function for stdio execution."""
stdio_transport = StdioTransport(session_store=session_store)
session_id = await stdio_transport.initialize_session(
session_attributes=session_attributes
)
await stdio_transport.handle_communication(
session_id=session_id,
mcp_process=stdio_transport.start_mcp_process(mcp_server_command_args),
)
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
"""
Splits CLI arguments into two parts:
1. Arguments intended for the MCP gateway (everything before `--exec`)
2. Arguments for the underlying MCP server (everything after `--exec`)
"""
if not args:
mcp_log("[ERROR] No arguments provided.")
sys.exit(1)
try:
exec_index = args.index("--exec")
except ValueError:
mcp_log("[ERROR] '--exec' flag not found in arguments.")
sys.exit(1)
mcp_gateway_args = args[:exec_index]
mcp_server_command_args = args[exec_index + 1 :]
if not mcp_server_command_args:
mcp_log("[ERROR] No arguments provided after '--exec'.")
sys.exit(1)
return mcp_gateway_args, mcp_server_command_args
async def execute(args: list[str] = None):
"""Main function to execute the MCP gateway using transport strategy."""
if "INVARIANT_API_KEY" not in os.environ:
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
sys.exit(1)
mcp_log("[INFO] Running with Python version:", sys.version)
# Parse arguments
mcp_gateway_args, mcp_server_command_args = split_args(args)
# Create session store and attributes
session_attributes = McpAttributes.from_cli_args(mcp_gateway_args)
# Use stdio transport strategy
await create_stdio_transport_and_execute(
session_store=mcp_sessions_manager,
session_attributes=session_attributes,
mcp_server_command_args=mcp_server_command_args,
)
@@ -15,11 +15,11 @@ from gateway.common.constants import (
MCP_TOOL_CALL,
UTF_8,
)
from gateway.common.mcp_sessions_manager import (
from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.common.mcp_utils import (
from gateway.mcp.utils import (
generate_session_id,
get_mcp_server_base_url,
hook_tool_call,
@@ -22,7 +22,7 @@ from gateway.common.constants import (
MCP_TOOL_CALL,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
from gateway.mcp.mcp_sessions_manager import (
McpSession,
McpSessionsManager,
)
+2 -2
View File
@@ -7,8 +7,8 @@ from starlette_compress import CompressMiddleware
from gateway.routes.anthropic import gateway as anthropic_gateway
from gateway.routes.gemini import gateway as gemini_gateway
from gateway.routes.open_ai import gateway as open_ai_gateway
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway
from gateway.mcp.sse import gateway as mcp_sse_gateway
from gateway.mcp.streamable import gateway as mcp_streamable_gateway
app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs",