mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-01 19:01:41 +02:00
Move MCP related routes to the MCP directory and introduce the MCPTransportBase class.
This commit is contained in:
+2
-2
@@ -9,7 +9,7 @@ import time
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from gateway.mcp import mcp
|
||||
from gateway.mcp import stdio as mcp_stdio
|
||||
from gateway.mcp.log import mcp_log
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
if verb == "mcp":
|
||||
return asyncio.run(mcp.execute(sys.argv[2:]))
|
||||
return asyncio.run(mcp_stdio.execute(sys.argv[2:]))
|
||||
|
||||
if verb == "server":
|
||||
if len(sys.argv) < 3:
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from gateway.common.constants import (
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
generate_session_id,
|
||||
hook_tool_call,
|
||||
intercept_response,
|
||||
update_mcp_server_in_session_metadata,
|
||||
update_session_from_request,
|
||||
)
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
STATUS_WAIT = "wait"
|
||||
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
def write_as_utf8_bytes(data: dict) -> bytes:
|
||||
"""Serializes dict to bytes using UTF-8 encoding."""
|
||||
return json.dumps(data).encode(UTF_8) + b"\n"
|
||||
|
||||
|
||||
async def stream_and_forward_stdout(
|
||||
session_id: str, mcp_process: subprocess.Popen
|
||||
) -> None:
|
||||
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
if mcp_process.poll() is not None:
|
||||
mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}")
|
||||
break
|
||||
|
||||
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
|
||||
if not line:
|
||||
break
|
||||
|
||||
try:
|
||||
# Process complete JSON lines
|
||||
decoded_line = line.decode(UTF_8).strip()
|
||||
if not decoded_line:
|
||||
continue
|
||||
session = session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] server -> client: {decoded_line}")
|
||||
response_body = json.loads(decoded_line)
|
||||
update_mcp_server_in_session_metadata(session, response_body)
|
||||
|
||||
intercept_response_result, _ = await intercept_response(
|
||||
session_id, session_store, response_body
|
||||
)
|
||||
# Write and flush immediately
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result))
|
||||
sys.stdout.buffer.flush()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
|
||||
if line:
|
||||
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
|
||||
|
||||
|
||||
async def stream_and_forward_stderr(
|
||||
mcp_process: subprocess.Popen, read_chunk_size: int = 10
|
||||
) -> None:
|
||||
"""Read from the mcp_process stderr and write to sys.stderr"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
# Read chunks asynchronously
|
||||
chunk = await loop.run_in_executor(
|
||||
None, lambda: mcp_process.stderr.read(read_chunk_size)
|
||||
)
|
||||
|
||||
MCP_LOG_FILE.buffer.write(chunk)
|
||||
MCP_LOG_FILE.buffer.flush()
|
||||
|
||||
|
||||
async def _intercept_request(
|
||||
session_id: str, mcp_process: subprocess.Popen, line: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Process a line of input from stdin, decode it and check for guardrails.
|
||||
If the request is blocked, it returns a message indicating the block reason
|
||||
otherwise it forwards the request to mcp_process stdin.
|
||||
"""
|
||||
session = session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] client -> server: {line}")
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
try:
|
||||
text = line.decode(UTF_8)
|
||||
request_body = json.loads(text)
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
return
|
||||
update_session_from_request(session, request_body)
|
||||
# Refresh guardrails
|
||||
await session.load_guardrails()
|
||||
|
||||
hook_tool_call_result = {}
|
||||
is_blocked = False
|
||||
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id, session_store, request_body
|
||||
)
|
||||
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
request_body={
|
||||
"id": request_body.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
|
||||
},
|
||||
)
|
||||
if is_blocked:
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
sys.stdout.buffer.flush()
|
||||
return
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(request_body))
|
||||
mcp_process.stdin.flush()
|
||||
|
||||
|
||||
async def wait_for_stdin_input(
|
||||
loop: asyncio.AbstractEventLoop, stdin_fd: int
|
||||
) -> tuple[bytes | None, str]:
|
||||
"""
|
||||
Platform-specific implementation to wait for and read input from stdin.
|
||||
|
||||
Args:
|
||||
loop: The asyncio event loop
|
||||
stdin_fd: The file descriptor for stdin
|
||||
|
||||
Returns:
|
||||
tuple[bytes | None, str]: A tuple containing:
|
||||
- The data read from stdin or None
|
||||
- Status: 'eof' if EOF detected, 'data' if data available, 'wait' if no data yet
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, we can't use select for stdin
|
||||
# Instead, we'll use a brief sleep and then try to read
|
||||
await asyncio.sleep(0.01)
|
||||
try:
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk: # Empty bytes means EOF
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
except (BlockingIOError, OSError):
|
||||
# No data available yet
|
||||
return None, STATUS_WAIT
|
||||
else:
|
||||
# On Unix-like systems, use select
|
||||
ready, _, _ = await loop.run_in_executor(
|
||||
None, lambda: select.select([stdin_fd], [], [], 0.1)
|
||||
)
|
||||
|
||||
if not ready:
|
||||
# No input available, yield to other tasks
|
||||
await asyncio.sleep(0.01)
|
||||
return None, STATUS_WAIT
|
||||
|
||||
# Read available data
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk: # Empty bytes means EOF
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
|
||||
|
||||
async def run_stdio_input_loop(
|
||||
session_id: str,
|
||||
mcp_process: subprocess.Popen,
|
||||
stdout_task: asyncio.Task,
|
||||
stderr_task: asyncio.Task,
|
||||
) -> None:
|
||||
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
|
||||
loop = asyncio.get_event_loop()
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
buffer = b""
|
||||
|
||||
# Set stdin to non-blocking mode
|
||||
os.set_blocking(stdin_fd, False)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get input using platform-specific method
|
||||
chunk, status = await wait_for_stdin_input(loop, stdin_fd)
|
||||
|
||||
if status == STATUS_EOF:
|
||||
# EOF detected, break the loop
|
||||
break
|
||||
elif status == STATUS_WAIT:
|
||||
# No data available yet, continue polling
|
||||
continue
|
||||
elif status == STATUS_DATA:
|
||||
# We got some data, process it
|
||||
buffer += chunk
|
||||
|
||||
# Process complete lines
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
|
||||
await _intercept_request(session_id, mcp_process, line)
|
||||
except (BrokenPipeError, KeyboardInterrupt):
|
||||
# Broken pipe = client disappeared, just start shutdown
|
||||
mcp_log("Client disconnected or keyboard interrupt")
|
||||
finally:
|
||||
# Close stdin
|
||||
if mcp_process.stdin:
|
||||
mcp_process.stdin.close()
|
||||
|
||||
# Process any remaining data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if line:
|
||||
await _intercept_request(session_id, mcp_process, line)
|
||||
|
||||
# Terminate process if needed
|
||||
if mcp_process.poll() is None:
|
||||
mcp_process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, mcp_process.wait), timeout=2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
mcp_process.kill()
|
||||
|
||||
# Cancel I/O tasks
|
||||
stdout_task.cancel()
|
||||
stderr_task.cancel()
|
||||
|
||||
# Final flush
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Splits CLI arguments into two parts:
|
||||
1. Arguments intended for the MCP gateway (everything before `--exec`)
|
||||
2. Arguments for the underlying MCP server (everything after `--exec`)
|
||||
|
||||
Parameters:
|
||||
args (list[str]): The list of CLI arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args)
|
||||
"""
|
||||
if not args:
|
||||
mcp_log("[ERROR] No arguments provided.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
exec_index = args.index("--exec")
|
||||
except ValueError:
|
||||
mcp_log("[ERROR] '--exec' flag not found in arguments.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_gateway_args = args[:exec_index]
|
||||
mcp_server_command_args = args[exec_index + 1 :]
|
||||
|
||||
if not mcp_server_command_args:
|
||||
mcp_log("[ERROR] No arguments provided after '--exec'.")
|
||||
sys.exit(1)
|
||||
|
||||
return mcp_gateway_args, mcp_server_command_args
|
||||
|
||||
|
||||
async def execute(args: list[str] = None):
|
||||
"""Main function to execute the MCP gateway."""
|
||||
if "INVARIANT_API_KEY" not in os.environ:
|
||||
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_log("[INFO] Running with Python version:", sys.version)
|
||||
|
||||
mcp_gateway_args, mcp_server_command_args = split_args(args)
|
||||
session_id = generate_session_id()
|
||||
await session_store.initialize_session(
|
||||
session_id,
|
||||
McpAttributes.from_cli_args(mcp_gateway_args),
|
||||
)
|
||||
|
||||
mcp_process = subprocess.Popen(
|
||||
mcp_server_command_args,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0,
|
||||
)
|
||||
|
||||
# Start async tasks for stdout and stderr
|
||||
stdout_task = asyncio.create_task(
|
||||
stream_and_forward_stdout(session_id, mcp_process)
|
||||
)
|
||||
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
MCP Transport Strategy Pattern Implementation
|
||||
|
||||
This module defines an abstract base class for MCP transports.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Tuple
|
||||
|
||||
from gateway.common.constants import (
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
)
|
||||
from gateway.mcp.mcp_sessions_manager import McpSessionsManager
|
||||
from gateway.mcp.utils import (
|
||||
hook_tool_call,
|
||||
intercept_response,
|
||||
update_mcp_server_in_session_metadata,
|
||||
update_session_from_request,
|
||||
)
|
||||
|
||||
|
||||
class MCPTransportBase(ABC):
|
||||
"""
|
||||
Abstract base class for MCP transport strategies.
|
||||
|
||||
This class defines the common interface and shared functionality for all MCP transports,
|
||||
using the Template Method pattern for request/response processing.
|
||||
"""
|
||||
|
||||
def __init__(self, session_store: McpSessionsManager):
|
||||
self.session_store = session_store
|
||||
|
||||
async def process_outgoing_request(
|
||||
self, session_id: str, request_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
"""
|
||||
Template method for processing outgoing requests to MCP server.
|
||||
|
||||
Returns:
|
||||
Tuple[processed_request_data, is_blocked]
|
||||
"""
|
||||
# Update session with request information
|
||||
session = self.session_store.get_session(session_id)
|
||||
update_session_from_request(session, request_data)
|
||||
|
||||
# Refresh guardrails
|
||||
await session.load_guardrails()
|
||||
|
||||
# Check if request should be intercepted for guardrails
|
||||
if self._should_intercept_request(request_data):
|
||||
return await self._intercept_outgoing_request(session_id, request_data)
|
||||
|
||||
return request_data, False
|
||||
|
||||
async def process_incoming_response(
|
||||
self, session_id: str, response_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
"""
|
||||
Template method for processing incoming responses from MCP server.
|
||||
|
||||
Returns:
|
||||
Tuple[processed_response, is_blocked]
|
||||
"""
|
||||
# Update session with server information
|
||||
session = self.session_store.get_session(session_id)
|
||||
update_mcp_server_in_session_metadata(session, response_data)
|
||||
|
||||
# Intercept and apply guardrails to response
|
||||
return await intercept_response(session_id, self.session_store, response_data)
|
||||
|
||||
def _should_intercept_request(self, request_data: dict[str, Any]) -> bool:
|
||||
"""Check if request should be intercepted for guardrails."""
|
||||
method = request_data.get(MCP_METHOD)
|
||||
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
|
||||
|
||||
async def _intercept_outgoing_request(
|
||||
self, session_id: str, request_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
"""Common request interception logic for guardrails."""
|
||||
method = request_data.get(MCP_METHOD)
|
||||
|
||||
interception_result = request_data
|
||||
is_blocked = False
|
||||
if method == MCP_TOOL_CALL:
|
||||
interception_result, is_blocked = await hook_tool_call(
|
||||
session_id, self.session_store, request_data
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
interception_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=self.session_store,
|
||||
request_body={
|
||||
"id": request_data.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
|
||||
},
|
||||
)
|
||||
|
||||
return interception_result, is_blocked
|
||||
|
||||
def _is_initialization_request(self, request_data: dict[str, Any]) -> bool:
|
||||
"""Check if request is an initialization request."""
|
||||
return (
|
||||
request_data.get("method") in ["initialize", "notifications/initialized"]
|
||||
and "jsonrpc" in request_data
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
"""Initialize a session for this transport type."""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_communication(self, *args, **kwargs) -> Any:
|
||||
"""Handle the main communication for this transport."""
|
||||
@@ -17,11 +17,11 @@ from gateway.common.constants import (
|
||||
MCP_LIST_TOOLS,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
from gateway.mcp.utils import (
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
intercept_response,
|
||||
@@ -0,0 +1,314 @@
|
||||
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gateway.common.constants import (
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.mcp.utils import (
|
||||
generate_session_id,
|
||||
)
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
STATUS_WAIT = "wait"
|
||||
mcp_sessions_manager = McpSessionsManager()
|
||||
|
||||
|
||||
class StdioTransport(MCPTransportBase):
|
||||
"""
|
||||
STDIO transport implementation for MCP communication.
|
||||
Handles subprocess-based communication with stdin/stdout/stderr.
|
||||
"""
|
||||
|
||||
def __init__(self, session_store: McpSessionsManager):
|
||||
super().__init__(session_store)
|
||||
self.mcp_process: subprocess.Popen = None
|
||||
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
"""Initialize session for stdio transport."""
|
||||
session_attributes: McpAttributes = kwargs.get("session_attributes")
|
||||
session_id = generate_session_id()
|
||||
await self.session_store.initialize_session(session_id, session_attributes)
|
||||
mcp_log(f"Created stdio session with ID: {session_id}")
|
||||
return session_id
|
||||
|
||||
def start_mcp_process(self, mcp_server_command_args: list) -> subprocess.Popen:
|
||||
"""Start the MCP server subprocess."""
|
||||
self.mcp_process = subprocess.Popen(
|
||||
mcp_server_command_args,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0,
|
||||
)
|
||||
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
|
||||
return self.mcp_process
|
||||
|
||||
async def handle_communication(self, *args, **kwargs) -> None:
|
||||
"""Handle stdio communication loop."""
|
||||
session_id: str = kwargs.get("session_id")
|
||||
mcp_process: subprocess.Popen = kwargs.get("mcp_process")
|
||||
if not session_id or not mcp_process:
|
||||
raise ValueError(
|
||||
"session_id and mcp_process are required for stdio transport"
|
||||
)
|
||||
|
||||
self.mcp_process = mcp_process
|
||||
|
||||
# Start async tasks for stdout and stderr
|
||||
stdout_task = asyncio.create_task(self._stream_and_forward_stdout(session_id))
|
||||
stderr_task = asyncio.create_task(self._stream_and_forward_stderr())
|
||||
|
||||
try:
|
||||
# Handle stdin input loop
|
||||
await self._run_stdio_input_loop(session_id)
|
||||
finally:
|
||||
# Cleanup
|
||||
if self.mcp_process and self.mcp_process.stdin:
|
||||
self.mcp_process.stdin.close()
|
||||
|
||||
# Terminate process if needed
|
||||
if self.mcp_process and self.mcp_process.poll() is None:
|
||||
self.mcp_process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, self.mcp_process.wait
|
||||
),
|
||||
timeout=2,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self.mcp_process.kill()
|
||||
|
||||
# Cancel I/O tasks
|
||||
stdout_task.cancel()
|
||||
stderr_task.cancel()
|
||||
|
||||
# Final flush
|
||||
sys.stdout.flush()
|
||||
|
||||
async def _stream_and_forward_stdout(self, session_id: str) -> None:
|
||||
"""Read from MCP process stdout, apply guardrails and forward to sys.stdout."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
if self.mcp_process.poll() is not None:
|
||||
mcp_log(
|
||||
f"[ERROR] MCP process terminated with code: {self.mcp_process.poll()}"
|
||||
)
|
||||
break
|
||||
|
||||
line = await loop.run_in_executor(None, self.mcp_process.stdout.readline)
|
||||
if not line:
|
||||
break
|
||||
|
||||
try:
|
||||
decoded_line = line.decode(UTF_8).strip()
|
||||
if not decoded_line:
|
||||
continue
|
||||
|
||||
session = self.session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] server -> client: {decoded_line}")
|
||||
|
||||
response_body = json.loads(decoded_line)
|
||||
processed_response, _ = await self.process_incoming_response(
|
||||
session_id, response_body
|
||||
)
|
||||
|
||||
sys.stdout.buffer.write(self._serialize_to_bytes(processed_response))
|
||||
sys.stdout.buffer.flush()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
mcp_log(f"[ERROR] Error in _stream_and_forward_stdout: {str(e)}")
|
||||
if line:
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
|
||||
async def _stream_and_forward_stderr(self) -> None:
|
||||
"""Read from MCP process stderr and write to log file."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
chunk = await loop.run_in_executor(
|
||||
None, lambda: self.mcp_process.stderr.read(10)
|
||||
)
|
||||
if not chunk:
|
||||
break
|
||||
MCP_LOG_FILE.buffer.write(chunk)
|
||||
MCP_LOG_FILE.buffer.flush()
|
||||
|
||||
async def _run_stdio_input_loop(self, session_id: str) -> None:
|
||||
"""Handle standard input, intercept calls and forward requests to MCP process stdin."""
|
||||
loop = asyncio.get_event_loop()
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
buffer = b""
|
||||
|
||||
# Set stdin to non-blocking mode
|
||||
os.set_blocking(stdin_fd, False)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get input using platform-specific method
|
||||
chunk, status = await self._wait_for_stdin_input(loop, stdin_fd)
|
||||
|
||||
if status == STATUS_EOF:
|
||||
break
|
||||
elif status == STATUS_WAIT:
|
||||
continue
|
||||
elif status == STATUS_DATA:
|
||||
buffer += chunk
|
||||
|
||||
# Process complete lines
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
|
||||
await self._process_stdin_line(session_id, line)
|
||||
|
||||
except (BrokenPipeError, KeyboardInterrupt):
|
||||
mcp_log("Client disconnected or keyboard interrupt")
|
||||
finally:
|
||||
# Process any remaining data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if line:
|
||||
await self._process_stdin_line(session_id, line)
|
||||
|
||||
async def _process_stdin_line(self, session_id: str, line: bytes) -> None:
|
||||
"""Process a line of input from stdin."""
|
||||
session = self.session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] client -> server: {line}")
|
||||
|
||||
try:
|
||||
text = line.decode(UTF_8)
|
||||
request_body = json.loads(text)
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error: {str(je)}")
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
return
|
||||
|
||||
processed_request, is_blocked = await self.process_outgoing_request(
|
||||
session_id, request_body
|
||||
)
|
||||
|
||||
if is_blocked:
|
||||
sys.stdout.buffer.write(self._serialize_to_bytes(processed_request))
|
||||
sys.stdout.buffer.flush()
|
||||
return
|
||||
self.mcp_process.stdin.write(self._serialize_to_bytes(request_body))
|
||||
self.mcp_process.stdin.flush()
|
||||
|
||||
async def _wait_for_stdin_input(
|
||||
self, loop: asyncio.AbstractEventLoop, stdin_fd: int
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
"""Platform-specific implementation to wait for and read input from stdin."""
|
||||
if platform.system() == "Windows":
|
||||
await asyncio.sleep(0.01)
|
||||
try:
|
||||
chunk = await loop.run_in_executor(
|
||||
None, lambda: os.read(stdin_fd, 4096)
|
||||
)
|
||||
if not chunk:
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
except (BlockingIOError, OSError):
|
||||
return None, STATUS_WAIT
|
||||
else:
|
||||
# Unix-like systems
|
||||
ready, _, _ = await loop.run_in_executor(
|
||||
None, lambda: select.select([stdin_fd], [], [], 0.1)
|
||||
)
|
||||
|
||||
if not ready:
|
||||
await asyncio.sleep(0.01)
|
||||
return None, STATUS_WAIT
|
||||
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk:
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
|
||||
def _serialize_to_bytes(self, data: dict) -> bytes:
|
||||
"""Serialize dict to bytes using UTF-8 encoding."""
|
||||
return json.dumps(data).encode(UTF_8) + b"\n"
|
||||
|
||||
|
||||
async def create_stdio_transport_and_execute(
|
||||
session_store: McpSessionsManager,
|
||||
session_attributes: McpAttributes,
|
||||
mcp_server_command_args: list,
|
||||
) -> None:
|
||||
"""Integration function for stdio execution."""
|
||||
stdio_transport = StdioTransport(session_store=session_store)
|
||||
|
||||
session_id = await stdio_transport.initialize_session(
|
||||
session_attributes=session_attributes
|
||||
)
|
||||
|
||||
await stdio_transport.handle_communication(
|
||||
session_id=session_id,
|
||||
mcp_process=stdio_transport.start_mcp_process(mcp_server_command_args),
|
||||
)
|
||||
|
||||
|
||||
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Splits CLI arguments into two parts:
|
||||
1. Arguments intended for the MCP gateway (everything before `--exec`)
|
||||
2. Arguments for the underlying MCP server (everything after `--exec`)
|
||||
"""
|
||||
if not args:
|
||||
mcp_log("[ERROR] No arguments provided.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
exec_index = args.index("--exec")
|
||||
except ValueError:
|
||||
mcp_log("[ERROR] '--exec' flag not found in arguments.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_gateway_args = args[:exec_index]
|
||||
mcp_server_command_args = args[exec_index + 1 :]
|
||||
|
||||
if not mcp_server_command_args:
|
||||
mcp_log("[ERROR] No arguments provided after '--exec'.")
|
||||
sys.exit(1)
|
||||
|
||||
return mcp_gateway_args, mcp_server_command_args
|
||||
|
||||
|
||||
async def execute(args: list[str] = None):
|
||||
"""Main function to execute the MCP gateway using transport strategy."""
|
||||
if "INVARIANT_API_KEY" not in os.environ:
|
||||
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_log("[INFO] Running with Python version:", sys.version)
|
||||
|
||||
# Parse arguments
|
||||
mcp_gateway_args, mcp_server_command_args = split_args(args)
|
||||
|
||||
# Create session store and attributes
|
||||
session_attributes = McpAttributes.from_cli_args(mcp_gateway_args)
|
||||
|
||||
# Use stdio transport strategy
|
||||
await create_stdio_transport_and_execute(
|
||||
session_store=mcp_sessions_manager,
|
||||
session_attributes=session_attributes,
|
||||
mcp_server_command_args=mcp_server_command_args,
|
||||
)
|
||||
@@ -15,11 +15,11 @@ from gateway.common.constants import (
|
||||
MCP_TOOL_CALL,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
from gateway.mcp.utils import (
|
||||
generate_session_id,
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
@@ -22,7 +22,7 @@ from gateway.common.constants import (
|
||||
MCP_TOOL_CALL,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSession,
|
||||
McpSessionsManager,
|
||||
)
|
||||
+2
-2
@@ -7,8 +7,8 @@ from starlette_compress import CompressMiddleware
|
||||
from gateway.routes.anthropic import gateway as anthropic_gateway
|
||||
from gateway.routes.gemini import gateway as gemini_gateway
|
||||
from gateway.routes.open_ai import gateway as open_ai_gateway
|
||||
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
|
||||
from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway
|
||||
from gateway.mcp.sse import gateway as mcp_sse_gateway
|
||||
from gateway.mcp.streamable import gateway as mcp_streamable_gateway
|
||||
|
||||
app = fastapi.app = fastapi.FastAPI(
|
||||
docs_url="/api/v1/gateway/docs",
|
||||
|
||||
Reference in New Issue
Block a user