mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Small cleanups in MCP related code.
This commit is contained in:
@@ -4,7 +4,6 @@ MCP Transport Strategy Pattern Implementation
|
||||
This module defines an abstract base class for MCP transports.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
@@ -31,7 +30,7 @@ from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
|
||||
|
||||
|
||||
class MCPTransportBase(ABC):
|
||||
class McpTransportBase(ABC):
|
||||
"""
|
||||
Abstract base class for MCP transport strategies.
|
||||
|
||||
@@ -53,7 +52,7 @@ class MCPTransportBase(ABC):
|
||||
"""
|
||||
# Update session with request information
|
||||
session = self.session_store.get_session(session_id)
|
||||
MCPTransportBase.update_session_from_request(session, request_data)
|
||||
McpTransportBase.update_session_from_request(session, request_data)
|
||||
|
||||
# Refresh guardrails
|
||||
await session.load_guardrails()
|
||||
@@ -75,10 +74,10 @@ class MCPTransportBase(ABC):
|
||||
"""
|
||||
# Update session with server information
|
||||
session = self.session_store.get_session(session_id)
|
||||
MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data)
|
||||
McpTransportBase.update_mcp_server_in_session_metadata(session, response_data)
|
||||
|
||||
# Intercept and apply guardrails to response
|
||||
return await MCPTransportBase.intercept_response(
|
||||
return await McpTransportBase.intercept_response(
|
||||
session_id, self.session_store, response_data
|
||||
)
|
||||
|
||||
@@ -87,6 +86,17 @@ class MCPTransportBase(ABC):
|
||||
method = request_data.get(MCP_METHOD)
|
||||
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
|
||||
|
||||
@staticmethod
|
||||
def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
|
||||
async def _intercept_outgoing_request(
|
||||
self, session_id: str, request_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
@@ -96,11 +106,11 @@ class MCPTransportBase(ABC):
|
||||
interception_result = request_data
|
||||
is_blocked = False
|
||||
if method == MCP_TOOL_CALL:
|
||||
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
|
||||
interception_result, is_blocked = await McpTransportBase.hook_tool_call(
|
||||
session_id, self.session_store, request_data
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
|
||||
interception_result, is_blocked = await McpTransportBase.hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=self.session_store,
|
||||
request_body={
|
||||
@@ -152,8 +162,8 @@ class MCPTransportBase(ABC):
|
||||
@staticmethod
|
||||
def update_session_from_request(session: McpSession, request_body: dict) -> None:
|
||||
"""Update the MCP client information and request id in the session."""
|
||||
MCPTransportBase.update_mcp_client_info_in_session(session, request_body)
|
||||
MCPTransportBase.update_tool_call_id_in_session(session, request_body)
|
||||
McpTransportBase.update_mcp_client_info_in_session(session, request_body)
|
||||
McpTransportBase.update_tool_call_id_in_session(session, request_body)
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_server_base_url(request: Request) -> str:
|
||||
@@ -164,7 +174,7 @@ class MCPTransportBase(ABC):
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
return MCPTransportBase.convert_localhost_to_docker_host(
|
||||
return McpTransportBase.convert_localhost_to_docker_host(
|
||||
mcp_server_base_url
|
||||
).rstrip("/")
|
||||
|
||||
@@ -233,7 +243,7 @@ class MCPTransportBase(ABC):
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and MCPTransportBase.check_if_new_errors(
|
||||
and McpTransportBase.check_if_new_errors(
|
||||
session_id, session_store, guardrails_result
|
||||
)
|
||||
):
|
||||
@@ -243,15 +253,10 @@ class MCPTransportBase(ABC):
|
||||
message=message,
|
||||
guardrails_result=guardrails_result,
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}, True
|
||||
return McpTransportBase._create_jsonrpc_error_response(
|
||||
request_body,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
|
||||
), True
|
||||
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(
|
||||
@@ -298,22 +303,17 @@ class MCPTransportBase(ABC):
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and MCPTransportBase.check_if_new_errors(
|
||||
and McpTransportBase.check_if_new_errors(
|
||||
session_id, session_store, guardrails_result
|
||||
)
|
||||
):
|
||||
is_blocked = True
|
||||
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
result = McpTransportBase._create_jsonrpc_error_response(
|
||||
response_body,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
|
||||
)
|
||||
else:
|
||||
# Special error response for tools/list
|
||||
result = {
|
||||
@@ -379,7 +379,7 @@ class MCPTransportBase(ABC):
|
||||
(
|
||||
intercept_response_result,
|
||||
is_blocked,
|
||||
) = await MCPTransportBase.hook_tool_call_response(
|
||||
) = await McpTransportBase.hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body=response_body,
|
||||
@@ -393,7 +393,7 @@ class MCPTransportBase(ABC):
|
||||
(
|
||||
intercept_response_result,
|
||||
is_blocked,
|
||||
) = await MCPTransportBase.hook_tool_call_response(
|
||||
) = await McpTransportBase.hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body={
|
||||
@@ -410,9 +410,9 @@ class MCPTransportBase(ABC):
|
||||
return intercept_response_result, is_blocked
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
async def initialize_session(self, **kwargs) -> str:
|
||||
"""Initialize a session for this transport type."""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_communication(self, *args, **kwargs) -> Any:
|
||||
async def handle_communication(self, **kwargs) -> Any:
|
||||
"""Handle the main communication for this transport."""
|
||||
|
||||
@@ -16,7 +16,7 @@ from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
MCP_SERVER_POST_HEADERS = {
|
||||
"connection",
|
||||
@@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post(
|
||||
raise HTTPException(status_code=400, detail="Session does not exist")
|
||||
|
||||
request_body = json.loads(await request.body())
|
||||
return await SSETransport(session_store).handle_post_request(
|
||||
return await SseTransport(session_store).handle_post_request(
|
||||
request, session_id, request_body
|
||||
)
|
||||
|
||||
@@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream(
|
||||
request: Request, session_store: McpSessionsManager
|
||||
) -> StreamingResponse:
|
||||
"""Integration function for SSE GET route."""
|
||||
return await SSETransport(session_store).handle_sse_stream(request)
|
||||
return await SseTransport(session_store).handle_sse_stream(request)
|
||||
|
||||
|
||||
class SSETransport(MCPTransportBase):
|
||||
class SseTransport(McpTransportBase):
|
||||
"""
|
||||
Server-Sent Events transport implementation for MCP communication.
|
||||
Handles HTTP-based SSE communication with message queuing.
|
||||
@@ -82,7 +82,6 @@ class SSETransport(MCPTransportBase):
|
||||
|
||||
async def initialize_session(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Initialize or get existing SSE session."""
|
||||
@@ -298,7 +297,7 @@ class SSETransport(MCPTransportBase):
|
||||
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
|
||||
)
|
||||
|
||||
async def handle_communication(self, *args, **kwargs) -> StreamingResponse:
|
||||
async def handle_communication(self, **kwargs) -> StreamingResponse:
|
||||
"""Main communication handler for SSE transport."""
|
||||
return await self.handle_sse_stream(kwargs.get("request"))
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from gateway.mcp.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
@@ -23,7 +23,7 @@ STATUS_WAIT = "wait"
|
||||
mcp_sessions_manager = McpSessionsManager()
|
||||
|
||||
|
||||
class StdioTransport(MCPTransportBase):
|
||||
class StdioTransport(McpTransportBase):
|
||||
"""
|
||||
STDIO transport implementation for MCP communication.
|
||||
Handles subprocess-based communication with stdin/stdout/stderr.
|
||||
@@ -33,7 +33,7 @@ class StdioTransport(MCPTransportBase):
|
||||
super().__init__(session_store)
|
||||
self.mcp_process: subprocess.Popen = None
|
||||
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
async def initialize_session(self, **kwargs) -> str:
|
||||
"""Initialize session for stdio transport."""
|
||||
session_attributes: McpAttributes = kwargs.get("session_attributes")
|
||||
session_id = self.generate_session_id()
|
||||
@@ -53,7 +53,7 @@ class StdioTransport(MCPTransportBase):
|
||||
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
|
||||
return self.mcp_process
|
||||
|
||||
async def handle_communication(self, *args, **kwargs) -> None:
|
||||
async def handle_communication(self, **kwargs) -> None:
|
||||
"""Handle stdio communication loop."""
|
||||
session_id: str = kwargs.get("session_id")
|
||||
mcp_process: subprocess.Popen = kwargs.get("mcp_process")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from httpx_sse import aconnect_sse
|
||||
@@ -18,7 +18,7 @@ from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
gateway = APIRouter()
|
||||
mcp_sessions_manager = McpSessionsManager()
|
||||
@@ -69,7 +69,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response:
|
||||
|
||||
async def create_streamable_transport_and_handle_request(
|
||||
request: Request, method: str, session_store: McpSessionsManager
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Integration function for streamable routes."""
|
||||
streamable_transport = StreamableTransport(session_store)
|
||||
return await streamable_transport.handle_communication(
|
||||
@@ -77,7 +77,7 @@ async def create_streamable_transport_and_handle_request(
|
||||
)
|
||||
|
||||
|
||||
class StreamableTransport(MCPTransportBase):
|
||||
class StreamableTransport(McpTransportBase):
|
||||
"""
|
||||
Streamable HTTP transport implementation for MCP communication.
|
||||
Handles HTTP POST/GET/DELETE requests with JSON and streaming responses.
|
||||
@@ -85,7 +85,6 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
async def initialize_session(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Initialize streamable HTTP session."""
|
||||
@@ -111,7 +110,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
async def handle_post_request(
|
||||
self, request: Request, request_body: dict[str, Any]
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Handle POST request to streamable endpoint."""
|
||||
session_attributes = McpAttributes.from_request_headers(request.headers)
|
||||
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
@@ -222,9 +221,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
print(f"[MCP DELETE] Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Request error") from e
|
||||
|
||||
async def handle_communication(
|
||||
self, *args, **kwargs
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
async def handle_communication(self, **kwargs) -> Response | StreamingResponse:
|
||||
"""Main communication handler for streamable transport."""
|
||||
request = kwargs.get("request")
|
||||
method = kwargs.get("method", "POST")
|
||||
@@ -262,7 +259,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
session_id: str,
|
||||
session_attributes: McpAttributes,
|
||||
is_initialization_request: bool,
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Forward request to MCP server and handle response."""
|
||||
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
|
||||
try:
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Task utilities for running async functions"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any
|
||||
|
||||
from gateway.mcp.log import MCP_LOG_FILE
|
||||
|
||||
|
||||
def run_task_sync(async_func, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Runs an asynchronous function synchronously in a separate
|
||||
thread with its own event loop. This function blocks the calling
|
||||
thread until completion or timeout (10 seconds).
|
||||
|
||||
Args:
|
||||
async_func: The async function to run
|
||||
*args: Positional arguments to pass to the async function
|
||||
**kwargs: Keyword arguments to pass to the async function
|
||||
|
||||
Returns:
|
||||
Any: The return value of the async function
|
||||
"""
|
||||
|
||||
def run_in_new_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
async_func(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
with redirect_stdout(MCP_LOG_FILE):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=10.0)
|
||||
Reference in New Issue
Block a user