diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py index fe24bc8..0130324 100644 --- a/gateway/common/mcp_sessions_manager.py +++ b/gateway/common/mcp_sessions_manager.py @@ -58,6 +58,7 @@ class McpSession(BaseModel): pending_error_messages: List[dict] = Field(default_factory=list) # Lock to maintain in-order pushes to explorer + # and other session-related operations _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) async def load_guardrails(self) -> None: @@ -292,24 +293,53 @@ class McpSessionsManager: def __init__(self): self._sessions: dict[str, McpSession] = {} + # Dictionary to store per-session locks. + # Used for session initialization and deletion. + self._session_locks: dict[str, asyncio.Lock] = {} + # Global lock to protect the locks dictionary itself + self._global_lock = asyncio.Lock() def session_exists(self, session_id: str) -> bool: """Check if a session exists""" return session_id in self._sessions + async def _get_session_lock(self, session_id: str) -> asyncio.Lock: + """ + Get a lock for a specific session ID, creating one if it doesn't exist. + Uses the global lock to protect access to the locks dictionary. + """ + async with self._global_lock: + if session_id not in self._session_locks: + self._session_locks[session_id] = asyncio.Lock() + return self._session_locks[session_id] + + async def cleanup_session_lock(self, session_id: str) -> None: + """Remove a session lock when it's no longer needed""" + async with self._global_lock: + if session_id in self._session_locks: + del self._session_locks[session_id] + async def initialize_session( self, session_id: str, sse_header_attributes: SseHeaderAttributes ) -> None: """Initialize a new session""" - if session_id not in self._sessions: - session = McpSession( - session_id=session_id, - explorer_dataset=sse_header_attributes.explorer_dataset, - push_explorer=sse_header_attributes.push_explorer, - ) - self._sessions[session_id] = session - # Load guardrails for the session from the explorer - await session.load_guardrails() + # Get the lock for this specific session + session_lock = await self._get_session_lock(session_id) + + # Acquire the lock for this session + async with session_lock: + # Check again if session exists (it might have been created while waiting for the lock) + if session_id not in self._sessions: + session = McpSession( + session_id=session_id, + explorer_dataset=sse_header_attributes.explorer_dataset, + push_explorer=sse_header_attributes.push_explorer, + ) + self._sessions[session_id] = session + # Load guardrails for the session from the explorer + await session.load_guardrails() + else: + print(f"Session {session_id} already exists, skipping initialization", flush=True) def get_session(self, session_id: str) -> McpSession: """Get a session by ID""" diff --git a/gateway/common/mcp_utils.py b/gateway/common/mcp_utils.py index c6f6f8a..a1fa293 100644 --- a/gateway/common/mcp_utils.py +++ b/gateway/common/mcp_utils.py @@ -48,4 +48,4 @@ def get_mcp_server_base_url(request: Request) -> str: status_code=400, detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) - return _convert_localhost_to_docker_host(mcp_server_base_url) + return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/") diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index 8bcd43e..e0091d3 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -50,17 +50,11 @@ session_store = McpSessionsManager() @gateway.post("/mcp/sse/messages/") -async def mcp_post_gateway( +async def mcp_post_sse_gateway( request: Request, ) -> Response: """Proxy calls to the MCP Server tools""" query_params = dict(request.query_params) - print("[MCP POST] Query params:", query_params, flush=True) - print( - "[MCP POST] Query params session_id:", - query_params.get("session_id"), - flush=True, - ) if not query_params.get("session_id"): raise HTTPException( status_code=400, @@ -193,6 +187,9 @@ async def mcp_get_sse_gateway( status_code=event_source.response.status_code, detail=error_content, ) + response_headers.update( + dict(event_source.response.headers.items()) + ) async for sse in event_source.aiter_sse(): if sse.event == "endpoint": diff --git a/gateway/routes/mcp_streamable.py b/gateway/routes/mcp_streamable.py index f978387..e6f9ae7 100644 --- a/gateway/routes/mcp_streamable.py +++ b/gateway/routes/mcp_streamable.py @@ -1,33 +1,157 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" +import json + +from gateway.common.constants import CLIENT_TIMEOUT +from gateway.common.mcp_sessions_manager import McpSessionsManager, SseHeaderAttributes from gateway.common.mcp_utils import get_mcp_server_base_url -from fastapi import APIRouter, Request, Response +import httpx +from httpx_sse import aconnect_sse +from fastapi import APIRouter, HTTPException, Request, Response +from fastapi.responses import StreamingResponse -MCP_SESSION_ID_HEADER = "mcp-session-id" gateway = APIRouter() +session_store = McpSessionsManager() + +MCP_SESSION_ID_HEADER = "mcp-session-id" +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" +MCP_SERVER_POST_HEADERS = { + "connection", + "accept", + "content-length", + "content-type", + MCP_SESSION_ID_HEADER, +} +MCP_SERVER_SSE_HEADERS = { + "connection", + "accept", + "cache-control", + MCP_SESSION_ID_HEADER, +} -def get_session_id(request: Request) -> str | None: +def get_session_id(request: Request) -> str: """Extract the session ID from request headers.""" - return request.headers.get(MCP_SESSION_ID_HEADER) + session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if not session_id: + raise HTTPException( + status_code=400, + detail=f"Missing {MCP_SESSION_ID_HEADER} header", + ) + return session_id + + +def get_mcp_server_endpoint(request: Request) -> str: + """ + Extract the MCP server endpoint from the request headers. + """ + return get_mcp_server_base_url(request) + "/mcp/" @gateway.post("/mcp/streamable") -async def mcp_post_gateway( +async def mcp_post_streamable_gateway( request: Request, -) -> Response: +) -> StreamingResponse: """ Forward a POST request to the MCP Streamable server. """ - mcp_server_base_url = get_mcp_server_base_url(request) - pass + body = await request.body() + session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Determine if this is an initialization request, only for our session tracking + try: + raw_message = json.loads(body) + is_initialization_request = ( + isinstance(raw_message, dict) + and raw_message.get("method") == "initialize" + and "jsonrpc" in raw_message + ) + except json.JSONDecodeError: + # Let the server handle the validation error + pass + + mcp_server_endpoint = get_mcp_server_endpoint(request) + filtered_headers = { + k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_POST_HEADERS + } + if ( + session_id + and not is_initialization_request + and not session_store.session_exists(session_id) + ): + raise HTTPException(status_code=404, detail="Invalid or expired session ID") + sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers) + + async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: + try: + response = await client.post( + url=mcp_server_endpoint, + headers=filtered_headers, + content=body, + follow_redirects=True, + ) + + # If we received a session ID from server, register it in our session store + # This happens in initialization responses + resp_session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if resp_session_id and not session_store.session_exists(resp_session_id): + await session_store.initialize_session( + resp_session_id, sse_header_attributes + ) + + # If the response is JSON, return it directly + if response.headers.get("content-type", "") == CONTENT_TYPE_JSON: + return Response( + content=response.content, + status_code=response.status_code, + headers={"X-Proxied-By": "mcp-gateway", **response.headers}, + ) + + # Else return SSE streaming response + async def event_generator(): + # Events have two parts: + # 1. event: {type} -> contains the type of event + # 2. data: {data} -> contains the actual message + # We are reading line by line so we need to buffer so that we can + # send the entire event (with both type and data) together. + # Once we receive an empty line, we end the stream. + buffer = "" + async for line in response.aiter_lines(): + if line.strip(): + if buffer: + complete_event = buffer + "\n" + line + "\n\n" + yield complete_event + # Clear the buffer for the next event + buffer = "" + else: + buffer = line + else: + # End stream here when line is empty. + break + + return StreamingResponse( + event_generator(), + media_type=CONTENT_TYPE_SSE, + headers={ + "X-Proxied-By": "mcp-gateway", + **response.headers, + }, + ) + + except httpx.RequestError as e: + print(f"[MCP POST] Request error: {str(e)}") + raise HTTPException(status_code=500, detail="Request error") from e + except Exception as e: + print(f"[MCP POST] Unexpected error: {str(e)}") + raise HTTPException(status_code=500, detail="Unexpected error") from e @gateway.get("/mcp/streamable") -async def mcp_get_gateway( +async def mcp_get_streamable_gateway( request: Request, ) -> Response: """ @@ -37,16 +161,91 @@ async def mcp_get_gateway( first sending data via HTTP POST. The server can send JSON-RPC requests and notifications on this stream. """ - mcp_server_base_url = get_mcp_server_base_url(request) - pass + mcp_server_endpoint = get_mcp_server_endpoint(request) + response_headers = {} + filtered_headers = { + k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS + } + + async def event_generator(): + """Connect to MCP server and process its events.""" + + async with httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) as client: + try: + async with aconnect_sse( + client, + "GET", + mcp_server_endpoint, + headers=filtered_headers, + ) as event_source: + if event_source.response.status_code != 200: + error_content = await event_source.response.aread() + raise HTTPException( + status_code=event_source.response.status_code, + detail=error_content, + ) + response_headers.update(dict(event_source.response.headers.items())) + + async for sse in event_source.aiter_sse(): + yield sse + + except httpx.StreamClosed as e: + print(f"Server stream closed: {e}", flush=True) + except Exception as e: # pylint: disable=broad-except + print(f"Error processing server events: {e}", flush=True) + + return StreamingResponse( + event_generator(), + media_type=CONTENT_TYPE_SSE, + headers={"X-Proxied-By": "mcp-gateway", **response_headers}, + ) @gateway.delete("/mcp/streamable") -async def mcp_delete_gateway( +async def mcp_delete_streamable_gateway( request: Request, ) -> Response: """ Forward a DELETE request to the MCP Streamable server for explicit session termination. """ - mcp_server_base_url = get_mcp_server_base_url(request) - pass + session_id = get_session_id(request) + if not session_store.session_exists(session_id): + raise HTTPException( + status_code=400, + detail="Session does not exist", + ) + mcp_server_endpoint = get_mcp_server_endpoint(request) + + async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: + try: + response = await client.delete( + url=mcp_server_endpoint, + headers={ + k: v + for k, v in request.headers.items() + if k.lower() + in { + "connection", + "accept", + "content-length", + "content-type", + MCP_SESSION_ID_HEADER, + } + }, + ) + await session_store.cleanup_session_lock(session_id) + return Response( + content=response.content, + status_code=response.status_code, + headers={ + "X-Proxied-By": "mcp-gateway", + **response.headers, + }, + ) + + except httpx.RequestError as e: + print(f"[MCP POST] Request error: {str(e)}") + raise HTTPException(status_code=500, detail="Request error") from e + except Exception as e: + print(f"[MCP POST] Unexpected error: {str(e)}") + raise HTTPException(status_code=500, detail="Unexpected error") from e