Small cleanups in MCP related code.

This commit is contained in:
Hemang
2025-06-04 13:59:41 +02:00
parent f3b5e2d0b1
commit 9053d89f25
5 changed files with 50 additions and 96 deletions

View File

@@ -4,7 +4,6 @@ MCP Transport Strategy Pattern Implementation
This module defines an abstract base class for MCP transports.
"""
import asyncio
import json
import re
import uuid
@@ -31,7 +30,7 @@ from gateway.mcp.log import format_errors_in_response
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
class MCPTransportBase(ABC):
class McpTransportBase(ABC):
"""
Abstract base class for MCP transport strategies.
@@ -53,7 +52,7 @@ class MCPTransportBase(ABC):
"""
# Update session with request information
session = self.session_store.get_session(session_id)
MCPTransportBase.update_session_from_request(session, request_data)
McpTransportBase.update_session_from_request(session, request_data)
# Refresh guardrails
await session.load_guardrails()
@@ -75,10 +74,10 @@ class MCPTransportBase(ABC):
"""
# Update session with server information
session = self.session_store.get_session(session_id)
MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data)
McpTransportBase.update_mcp_server_in_session_metadata(session, response_data)
# Intercept and apply guardrails to response
return await MCPTransportBase.intercept_response(
return await McpTransportBase.intercept_response(
session_id, self.session_store, response_data
)
@@ -87,6 +86,17 @@ class MCPTransportBase(ABC):
method = request_data.get(MCP_METHOD)
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
@staticmethod
def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict:
return {
"jsonrpc": "2.0",
"id": request_body.get("id"),
"error": {
"code": -32600,
"message": message,
},
}
async def _intercept_outgoing_request(
self, session_id: str, request_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
@@ -96,11 +106,11 @@ class MCPTransportBase(ABC):
interception_result = request_data
is_blocked = False
if method == MCP_TOOL_CALL:
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
interception_result, is_blocked = await McpTransportBase.hook_tool_call(
session_id, self.session_store, request_data
)
elif method == MCP_LIST_TOOLS:
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
interception_result, is_blocked = await McpTransportBase.hook_tool_call(
session_id=session_id,
session_store=self.session_store,
request_body={
@@ -152,8 +162,8 @@ class MCPTransportBase(ABC):
@staticmethod
def update_session_from_request(session: McpSession, request_body: dict) -> None:
"""Update the MCP client information and request id in the session."""
MCPTransportBase.update_mcp_client_info_in_session(session, request_body)
MCPTransportBase.update_tool_call_id_in_session(session, request_body)
McpTransportBase.update_mcp_client_info_in_session(session, request_body)
McpTransportBase.update_tool_call_id_in_session(session, request_body)
@staticmethod
def get_mcp_server_base_url(request: Request) -> str:
@@ -164,7 +174,7 @@ class MCPTransportBase(ABC):
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return MCPTransportBase.convert_localhost_to_docker_host(
return McpTransportBase.convert_localhost_to_docker_host(
mcp_server_base_url
).rstrip("/")
@@ -233,7 +243,7 @@ class MCPTransportBase(ABC):
if (
guardrails_result
and guardrails_result.get("errors", [])
and MCPTransportBase.check_if_new_errors(
and McpTransportBase.check_if_new_errors(
session_id, session_store, guardrails_result
)
):
@@ -243,15 +253,10 @@ class MCPTransportBase(ABC):
message=message,
guardrails_result=guardrails_result,
)
return {
"jsonrpc": "2.0",
"id": request_body.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}, True
return McpTransportBase._create_jsonrpc_error_response(
request_body,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
), True
# Push trace to the explorer
await session_store.add_message_to_session(
@@ -298,22 +303,17 @@ class MCPTransportBase(ABC):
if (
guardrails_result
and guardrails_result.get("errors", [])
and MCPTransportBase.check_if_new_errors(
and McpTransportBase.check_if_new_errors(
session_id, session_store, guardrails_result
)
):
is_blocked = True
if not is_tools_list:
result = {
"jsonrpc": "2.0",
"id": response_body.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}
result = McpTransportBase._create_jsonrpc_error_response(
response_body,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
)
else:
# Special error response for tools/list
result = {
@@ -379,7 +379,7 @@ class MCPTransportBase(ABC):
(
intercept_response_result,
is_blocked,
) = await MCPTransportBase.hook_tool_call_response(
) = await McpTransportBase.hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body=response_body,
@@ -393,7 +393,7 @@ class MCPTransportBase(ABC):
(
intercept_response_result,
is_blocked,
) = await MCPTransportBase.hook_tool_call_response(
) = await McpTransportBase.hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body={
@@ -410,9 +410,9 @@ class MCPTransportBase(ABC):
return intercept_response_result, is_blocked
@abstractmethod
async def initialize_session(self, *args, **kwargs) -> str:
async def initialize_session(self, **kwargs) -> str:
"""Initialize a session for this transport type."""
@abstractmethod
async def handle_communication(self, *args, **kwargs) -> Any:
async def handle_communication(self, **kwargs) -> Any:
"""Handle the main communication for this transport."""

View File

@@ -16,7 +16,7 @@ from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
MCP_SERVER_POST_HEADERS = {
"connection",
@@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post(
raise HTTPException(status_code=400, detail="Session does not exist")
request_body = json.loads(await request.body())
return await SSETransport(session_store).handle_post_request(
return await SseTransport(session_store).handle_post_request(
request, session_id, request_body
)
@@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream(
request: Request, session_store: McpSessionsManager
) -> StreamingResponse:
"""Integration function for SSE GET route."""
return await SSETransport(session_store).handle_sse_stream(request)
return await SseTransport(session_store).handle_sse_stream(request)
class SSETransport(MCPTransportBase):
class SseTransport(McpTransportBase):
"""
Server-Sent Events transport implementation for MCP communication.
Handles HTTP-based SSE communication with message queuing.
@@ -82,7 +82,6 @@ class SSETransport(MCPTransportBase):
async def initialize_session(
self,
*args,
**kwargs,
) -> str:
"""Initialize or get existing SSE session."""
@@ -298,7 +297,7 @@ class SSETransport(MCPTransportBase):
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
)
async def handle_communication(self, *args, **kwargs) -> StreamingResponse:
async def handle_communication(self, **kwargs) -> StreamingResponse:
"""Main communication handler for SSE transport."""
return await self.handle_sse_stream(kwargs.get("request"))

View File

@@ -15,7 +15,7 @@ from gateway.mcp.mcp_sessions_manager import (
McpAttributes,
McpSessionsManager,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
STATUS_EOF = "eof"
STATUS_DATA = "data"
@@ -23,7 +23,7 @@ STATUS_WAIT = "wait"
mcp_sessions_manager = McpSessionsManager()
class StdioTransport(MCPTransportBase):
class StdioTransport(McpTransportBase):
"""
STDIO transport implementation for MCP communication.
Handles subprocess-based communication with stdin/stdout/stderr.
@@ -33,7 +33,7 @@ class StdioTransport(MCPTransportBase):
super().__init__(session_store)
self.mcp_process: subprocess.Popen = None
async def initialize_session(self, *args, **kwargs) -> str:
async def initialize_session(self, **kwargs) -> str:
"""Initialize session for stdio transport."""
session_attributes: McpAttributes = kwargs.get("session_attributes")
session_id = self.generate_session_id()
@@ -53,7 +53,7 @@ class StdioTransport(MCPTransportBase):
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
return self.mcp_process
async def handle_communication(self, *args, **kwargs) -> None:
async def handle_communication(self, **kwargs) -> None:
"""Handle stdio communication loop."""
session_id: str = kwargs.get("session_id")
mcp_process: subprocess.Popen = kwargs.get("mcp_process")

View File

@@ -1,7 +1,7 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import json
from typing import Any, Optional, Union
from typing import Any, Optional
import httpx
from httpx_sse import aconnect_sse
@@ -18,7 +18,7 @@ from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
gateway = APIRouter()
mcp_sessions_manager = McpSessionsManager()
@@ -69,7 +69,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response:
async def create_streamable_transport_and_handle_request(
request: Request, method: str, session_store: McpSessionsManager
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Integration function for streamable routes."""
streamable_transport = StreamableTransport(session_store)
return await streamable_transport.handle_communication(
@@ -77,7 +77,7 @@ async def create_streamable_transport_and_handle_request(
)
class StreamableTransport(MCPTransportBase):
class StreamableTransport(McpTransportBase):
"""
Streamable HTTP transport implementation for MCP communication.
Handles HTTP POST/GET/DELETE requests with JSON and streaming responses.
@@ -85,7 +85,6 @@ class StreamableTransport(MCPTransportBase):
async def initialize_session(
self,
*args,
**kwargs,
) -> str:
"""Initialize streamable HTTP session."""
@@ -111,7 +110,7 @@ class StreamableTransport(MCPTransportBase):
async def handle_post_request(
self, request: Request, request_body: dict[str, Any]
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Handle POST request to streamable endpoint."""
session_attributes = McpAttributes.from_request_headers(request.headers)
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
@@ -222,9 +221,7 @@ class StreamableTransport(MCPTransportBase):
print(f"[MCP DELETE] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
async def handle_communication(
self, *args, **kwargs
) -> Union[Response, StreamingResponse]:
async def handle_communication(self, **kwargs) -> Response | StreamingResponse:
"""Main communication handler for streamable transport."""
request = kwargs.get("request")
method = kwargs.get("method", "POST")
@@ -262,7 +259,7 @@ class StreamableTransport(MCPTransportBase):
session_id: str,
session_attributes: McpAttributes,
is_initialization_request: bool,
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Forward request to MCP server and handle response."""
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:

View File

@@ -1,42 +0,0 @@
"""Task utilities for running async functions"""
import asyncio
import concurrent.futures
from contextlib import redirect_stdout
from typing import Any
from gateway.mcp.log import MCP_LOG_FILE
def run_task_sync(async_func, *args, **kwargs) -> Any:
"""
Runs an asynchronous function synchronously in a separate
thread with its own event loop. This function blocks the calling
thread until completion or timeout (10 seconds).
Args:
async_func: The async function to run
*args: Positional arguments to pass to the async function
**kwargs: Keyword arguments to pass to the async function
Returns:
Any: The return value of the async function
"""
def run_in_new_loop():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(
async_func(
*args,
**kwargs,
)
)
finally:
loop.close()
with redirect_stdout(MCP_LOG_FILE):
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result(timeout=10.0)