From 40ec6d2db2d1028b2b45e6a0e715e588094176cb Mon Sep 17 00:00:00 2001
From: Hemang
Date: Tue, 6 May 2025 19:14:03 +0530
Subject: [PATCH] Add MCP SSE server proxying in gateway.

---
 gateway/common/mcp_sessions_manager.py | 182 ++++++++++++++
 gateway/routes/mcp_sse.py              | 376 +++++++++++++++++++++++++
 gateway/serve.py                       |   3 +
 pyproject.toml                         |   1 +
 4 files changed, 562 insertions(+)
 create mode 100644 gateway/common/mcp_sessions_manager.py
 create mode 100644 gateway/routes/mcp_sse.py

diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py
new file mode 100644
index 0000000..e9c77be
--- /dev/null
+++ b/gateway/common/mcp_sessions_manager.py
@@ -0,0 +1,182 @@
+"""MCP Sessions Manager related classes"""
+
+import asyncio
+import contextlib
+import os
+import random
+
+from typing import Any, Dict, List, Optional
+
+from invariant_sdk.async_client import AsyncClient
+from invariant_sdk.types.append_messages import AppendMessagesRequest
+from invariant_sdk.types.push_traces import PushTracesRequest
+from pydantic import BaseModel, Field, PrivateAttr
+from starlette.datastructures import Headers
+
+DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
+
+
+class McpSession(BaseModel):
+    """
+    Represents a single MCP session.
+    """
+
+    session_id: str
+    messages: List[Dict[str, Any]] = Field(default_factory=list)
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+    id_to_method_mapping: Dict[int | str, str] = Field(default_factory=dict)
+    explorer_dataset: str
+    push_explorer: bool
+    trace_id: Optional[str] = None
+    last_trace_length: int = 0
+
+    # Lock to maintain in-order pushes to explorer
+    _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
+
+    @contextlib.asynccontextmanager
+    async def session_lock(self):
+        """
+        Context manager for the session lock.
+
+        Usage:
+        async with session.session_lock():
+            # Code that requires exclusive access to the session
+        """
+        async with self._lock:
+            yield
+
+    async def add_message(self, message: Dict[str, Any]) -> None:
+        """
+        Add a message to the session and optionally push to explorer.
+
+        Args:
+            message: The message to add
+        """
+        async with self.session_lock():
+            # pylint: disable=no-member
+            self.messages.append(message)
+            # If push_explorer is enabled, push the trace
+            if self.push_explorer:
+                await self._push_trace_update()
+
+    async def _push_trace_update(self) -> None:
+        """
+        Push trace updates to the explorer.
+
+        If a trace doesn't exist, create a new one. If it does, append new messages.
+
+        This is an internal method that should only be called within a lock.
+        """
+        try:
+            client = AsyncClient(
+                api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
+            )
+
+            # If no trace exists, create a new one
+            if not self.trace_id:
+                # pylint: disable=no-member
+                metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])}
+                if self.metadata.get("mcp_client_name"):
+                    metadata["mcp_client"] = self.metadata.get("mcp_client_name")
+                if self.metadata.get("mcp_server_name"):
+                    metadata["mcp_server"] = self.metadata.get("mcp_server_name")
+
+                response = await client.push_trace(
+                    PushTracesRequest(
+                        messages=[self.messages],
+                        dataset=self.explorer_dataset,
+                        metadata=[metadata],
+                    )
+                )
+                self.trace_id = response.id[0]
+            else:
+                new_messages = self.messages[self.last_trace_length :]
+                if new_messages:
+                    await client.append_messages(
+                        AppendMessagesRequest(
+                            trace_id=self.trace_id,
+                            messages=new_messages,
+                        )
+                    )
+            self.last_trace_length = len(self.messages)
+        except Exception as e:  # pylint: disable=broad-except
+            print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
+
+
+class SseHeaderAttributes(BaseModel):
+    """
+    A Pydantic model to represent header attributes.
+    """
+
+    push_explorer: bool
+    explorer_dataset: str
+
+    @classmethod
+    def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
+        """
+        Create an instance from FastAPI request headers.
+
+        Args:
+            headers: FastAPI Request headers
+
+        Returns:
+            SseHeaderAttributes: An instance with values extracted from headers
+        """
+        # Extract and process header values
+        project_name = headers.get("PROJECT-NAME")
+        push_explorer_header = headers.get("PUSH-EXPLORER", "false").lower()
+
+        # Determine explorer_dataset
+        if project_name:
+            explorer_dataset = project_name
+        else:
+            explorer_dataset = f"mcp-capture-{random.randint(1, 100)}"
+
+        # Determine push_explorer
+        push_explorer = push_explorer_header == "true"
+
+        # Create and return instance
+        return cls(push_explorer=push_explorer, explorer_dataset=explorer_dataset)
+
+
+class McpSessionsManager:
+    """
+    A class to manage MCP sessions and their messages.
+    """
+
+    def __init__(self):
+        self._sessions: dict[str, McpSession] = {}
+
+    def session_exists(self, session_id: str) -> bool:
+        """Check if a session exists"""
+        return session_id in self._sessions
+
+    def initialize_session(
+        self, session_id: str, sse_header_attributes: SseHeaderAttributes
+    ) -> None:
+        """Initialize a new session"""
+        if session_id not in self._sessions:
+            self._sessions[session_id] = McpSession(
+                session_id=session_id,
+                explorer_dataset=sse_header_attributes.explorer_dataset,
+                push_explorer=sse_header_attributes.push_explorer,
+            )
+
+    def get_session(self, session_id: str) -> McpSession:
+        """Get a session by ID"""
+        if session_id not in self._sessions:
+            raise ValueError(f"Session {session_id} does not exist.")
+        return self._sessions[session_id]
+
+    async def add_message_to_session(
+        self, session_id: str, message: Dict[str, Any]
+    ) -> None:
+        """
+        Add a message to a session and push to explorer if enabled.
+
+        Args:
+            session_id: The session ID
+            message: The message to add
+        """
+        session = self.get_session(session_id)
+        await session.add_message(message)
diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py
new file mode 100644
index 0000000..42cb70c
--- /dev/null
+++ b/gateway/routes/mcp_sse.py
@@ -0,0 +1,376 @@
+"""Gateway service to forward requests to the MCP SSE servers"""
+
+import asyncio
+import json
+import re
+from typing import Tuple
+
+import httpx
+from httpx_sse import aconnect_sse, ServerSentEvent
+from fastapi import APIRouter, HTTPException, Request, Response
+from fastapi.responses import StreamingResponse
+
+from gateway.common.constants import (
+    CLIENT_TIMEOUT,
+)
+from gateway.common.mcp_sessions_manager import (
+    McpSessionsManager,
+    SseHeaderAttributes,
+)
+
+
+MCP_METHOD = "method"
+MCP_TOOL_CALL = "tools/call"
+MCP_LIST_TOOLS = "tools/list"
+MCP_PARAMS = "params"
+MCP_RESULT = "result"
+MCP_SERVER_INFO = "serverInfo"
+MCP_CLIENT_INFO = "clientInfo"
+MCP_SERVER_POST_HEADERS = {
+    "connection",
+    "accept",
+    "content-length",
+    "content-type",
+}
+MCP_SERVER_SSE_HEADERS = {
+    "connection",
+    "accept",
+    "cache-control",
+}
+# Upstream response headers that are not copied onto the proxied response:
+# httpx returns the decoded body, so the upstream framing and encoding headers
+# no longer describe the bytes the gateway sends.
+MCP_RESPONSE_EXCLUDED_HEADERS = {
+    "connection",
+    "content-encoding",
+    "content-length",
+    "transfer-encoding",
+}
+
+gateway = APIRouter()
+session_store = McpSessionsManager()
+# Strong references to fire-and-forget tasks. The event loop only keeps weak
+# references, so an unreferenced task can be garbage-collected mid-execution.
+_background_tasks: set[asyncio.Task] = set()
+
+
+@gateway.post("/mcp/sse/messages/")
+async def mcp_post_gateway(
+    request: Request,
+) -> Response:
+    """Proxy calls to the MCP Server tools"""
+    query_params = dict(request.query_params)
+    # These must be raised, not returned: a returned HTTPException is treated
+    # as a normal response body instead of producing a 400.
+    if not query_params.get("session_id"):
+        raise HTTPException(
+            status_code=400,
+            detail="Missing 'session_id' query parameter",
+        )
+    if not session_store.session_exists(query_params.get("session_id")):
+        raise HTTPException(
+            status_code=400,
+            detail="Session does not exist",
+        )
+    if not request.headers.get("mcp-server-base-url"):
+        raise HTTPException(
+            status_code=400,
+            detail="Missing 'mcp-server-base-url' header",
+        )
+
+    session_id = query_params.get("session_id")
+    # The query string (including session_id) is supplied via `params` below.
+    mcp_server_messages_endpoint = (
+        _convert_localhost_to_docker_host(request.headers.get("mcp-server-base-url"))
+        + "/messages/"
+    )
+    request_body_bytes = await request.body()
+    request_json = json.loads(request_body_bytes)
+    session = session_store.get_session(session_id)
+    # JSON-RPC ids may legitimately be 0, so test for presence, not truthiness.
+    if request_json.get(MCP_METHOD) and request_json.get("id") is not None:
+        session.id_to_method_mapping[request_json.get("id")] = request_json.get(
+            MCP_METHOD
+        )
+    if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
+        MCP_CLIENT_INFO
+    ):
+        session.metadata["mcp_client_name"] = (
+            request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
+        )
+
+    if request_json.get(MCP_METHOD) == MCP_TOOL_CALL:
+        _hook_tool_call(session_id=session_id, request_json=request_json)
+
+    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
+        try:
+            response = await client.post(
+                url=mcp_server_messages_endpoint,
+                headers={
+                    k: v
+                    for k, v in request.headers.items()
+                    if k.lower() in MCP_SERVER_POST_HEADERS
+                },
+                # Send the original bytes so they match the forwarded
+                # Content-Length header exactly.
+                content=request_body_bytes,
+                params=query_params,
+            )
+            return Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers={
+                    "X-Proxied-By": "mcp-gateway",
+                    **{
+                        k: v
+                        for k, v in response.headers.items()
+                        if k.lower() not in MCP_RESPONSE_EXCLUDED_HEADERS
+                    },
+                },
+            )
+
+        except httpx.RequestError as e:
+            print(f"[MCP POST] Request error: {str(e)}")
+            raise HTTPException(status_code=500, detail="Request error") from e
+        except Exception as e:
+            print(f"[MCP POST] Unexpected error: {str(e)}")
+            raise HTTPException(status_code=500, detail="Unexpected error") from e
+
+
+@gateway.get("/mcp/sse")
+async def mcp_get_sse_gateway(
+    request: Request,
+) -> StreamingResponse:
+    """Proxy calls to the MCP Server tools"""
+    mcp_server_base_url = request.headers.get("mcp-server-base-url")
+    if not mcp_server_base_url:
+        raise HTTPException(
+            status_code=400,
+            detail="Missing 'mcp-server-base-url' header",
+        )
+    mcp_server_sse_endpoint = (
+        _convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
+    )
+
+    query_params = dict(request.query_params)
+    response_headers = {}
+
+    async def event_generator():
+        async with httpx.AsyncClient(
+            timeout=httpx.Timeout(CLIENT_TIMEOUT),
+            headers={
+                k: v
+                for k, v in request.headers.items()
+                if k.lower() in MCP_SERVER_SSE_HEADERS
+            },
+        ) as client:
+            try:
+                async with aconnect_sse(
+                    client,
+                    "GET",
+                    mcp_server_sse_endpoint,
+                    params=query_params,
+                ) as event_source:
+                    if event_source.response.status_code != 200:
+                        error_content = await event_source.response.aread()
+                        raise HTTPException(
+                            status_code=event_source.response.status_code,
+                            detail=error_content,
+                        )
+
+                    session_id = None
+
+                    async for sse in event_source.aiter_sse():
+                        event_bytes = (
+                            f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
+                        )
+                        match sse.event:
+                            case "endpoint":
+                                (
+                                    event_bytes,
+                                    session_id,
+                                ) = _handle_endpoint_event(
+                                    sse,
+                                    sse_header_attributes=SseHeaderAttributes.from_request_headers(
+                                        request.headers
+                                    ),
+                                )
+                            case "message":
+                                if session_id:
+                                    event_bytes = _handle_message_event(
+                                        session_id=session_id, sse=sse
+                                    )
+                        yield event_bytes
+
+            except httpx.StreamClosed as e:
+                print(f"[MCP SSE] Stream closed: {str(e)}", flush=True)
+            except httpx.RequestError as e:
+                print(f"[MCP SSE] Request error: {str(e)}", flush=True)
+            except Exception as e:  # pylint: disable=broad-except
+                print(f"[MCP SSE] Unexpected error: {str(e)}", flush=True)
+
+    # Return the streaming response
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={"X-Proxied-By": "mcp-gateway", **response_headers},
+    )
+
+
+def _run_in_background(coro) -> None:
+    """
+    Schedule a coroutine without awaiting it, keeping the task alive until done.
+
+    Args:
+        coro: The coroutine to schedule on the running event loop.
+    """
+    task = asyncio.create_task(coro)
+    _background_tasks.add(task)
+    task.add_done_callback(_background_tasks.discard)
+
+
+def _hook_tool_call(session_id: str, request_json: dict) -> None:
+    """
+    Hook to process the request JSON before sending it to the MCP server.
+
+    Args:
+        session_id (str): The session ID associated with the request.
+        request_json (dict): The request JSON to be processed.
+    """
+    # A malformed tools/call may omit params; record it without crashing.
+    params = request_json.get(MCP_PARAMS) or {}
+    tool_call = {
+        "id": f"call_{request_json.get('id')}",
+        "type": "function",
+        "function": {
+            "name": params.get("name"),
+            "arguments": params.get("arguments"),
+        },
+    }
+    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
+    # Push trace to the explorer - don't block on its response
+    _run_in_background(session_store.add_message_to_session(session_id, message))
+
+
+def _hook_tool_call_response(session_id: str, response_json: dict) -> None:
+    """
+    Hook to process the response JSON after receiving it from the MCP server.
+
+    Args:
+        session_id (str): The session ID associated with the request.
+        response_json (dict): The response JSON to be processed.
+    """
+    # JSON-RPC error responses carry a top-level "error" and no "result".
+    result = response_json.get(MCP_RESULT) or {}
+    message = {
+        "role": "tool",
+        "tool_call_id": f"call_{response_json.get('id')}",
+        "content": result.get("content"),
+        "error": result.get("error") or response_json.get("error"),
+    }
+    # Push trace to the explorer - don't block on its response
+    _run_in_background(session_store.add_message_to_session(session_id, message))
+
+
+def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
+    """
+    Convert localhost or 127.0.0.1 in an address to host.docker.internal
+
+    Args:
+        mcp_server_base_url (str): The original server address from the header
+
+    Returns:
+        str: Modified server address with localhost references changed to host.docker.internal
+    """
+    if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
+        # Replace localhost or 127.0.0.1 with host.docker.internal
+        modified_address = re.sub(
+            r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
+            r"\1host.docker.internal\2",
+            mcp_server_base_url,
+        )
+        return modified_address
+
+    return mcp_server_base_url
+
+
+def _handle_endpoint_event(
+    sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
+) -> Tuple[bytes, str | None]:
+    """
+    Handle the endpoint event type and modify the data accordingly.
+    For endpoint events, we need to rewrite the endpoint to use our gateway.
+
+    Args:
+        sse (ServerSentEvent): The original SSE object.
+        sse_header_attributes (SseHeaderAttributes): The header attributes from the request.
+
+    Returns:
+        bytes: Modified SSE data as bytes.
+        str | None: session_id extracted from the data, or None if absent.
+    """
+    # Extract session_id; stays None when the event carries no session_id so
+    # the return below cannot hit an unbound local.
+    session_id = None
+    match = re.search(r"session_id=([^&\s]+)", sse.data)
+    if match:
+        session_id = match.group(1)
+        # Initialize this session in our store if needed
+        if not session_store.session_exists(session_id):
+            session_store.initialize_session(session_id, sse_header_attributes)
+
+    # Rewrite the endpoint to use our gateway
+    modified_data = sse.data.replace(
+        "/messages/?session_id=",
+        "/api/v1/gateway/mcp/sse/messages/?session_id=",
+    )
+    event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
+    return event_bytes, session_id
+
+
+def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
+    """
+    Handle the message event type.
+
+    Args:
+        session_id (str): The session ID associated with the request.
+        sse (ServerSentEvent): The original SSE object.
+
+    Returns:
+        bytes: The SSE event re-encoded as bytes, unchanged in content.
+    """
+    event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
+    session = session_store.get_session(session_id)
+    try:
+        response_json = json.loads(sse.data)
+
+        if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
+            MCP_SERVER_INFO
+        ):
+            session.metadata["mcp_server_name"] = (
+                response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
+            )
+
+        # Each request id is answered once; pop so the mapping does not grow
+        # for the lifetime of the session.
+        method = session.id_to_method_mapping.pop(response_json.get("id"), None)
+        if method == MCP_TOOL_CALL:
+            _hook_tool_call_response(
+                session_id=session_id,
+                response_json=response_json,
+            )
+        elif method == MCP_LIST_TOOLS:
+            session_store.get_session(session_id).metadata["tools"] = response_json.get(
+                MCP_RESULT
+            ).get("tools")
+    except json.JSONDecodeError as e:
+        print(
+            f"[MCP SSE] Error parsing message JSON: {e}",
+            flush=True,
+        )
+    except Exception as e:  # pylint: disable=broad-except
+        print(
+            f"[MCP SSE] Error processing message: {e}",
+            flush=True,
+        )
+    return event_bytes
diff --git a/gateway/serve.py b/gateway/serve.py
index 2568d6c..b8c8c5e 100644
--- a/gateway/serve.py
+++ b/gateway/serve.py
@@ -7,6 +7,7 @@ from starlette_compress import CompressMiddleware
 from gateway.routes.anthropic import gateway as anthropic_gateway
 from gateway.routes.gemini import gateway as gemini_gateway
 from gateway.routes.open_ai import gateway as open_ai_gateway
+from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
 
 app = fastapi.app = fastapi.FastAPI(
     docs_url="/api/v1/gateway/docs",
@@ -30,6 +31,8 @@ router.include_router(anthropic_gateway, prefix="/gateway", tags=["anthropic_gat
 
 router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"])
 
+router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
+
 app.include_router(router)
 
 
diff --git a/pyproject.toml b/pyproject.toml
index 561c288..3441f81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
 dependencies = [
     "fastapi==0.115.7",
     "httpx==0.28.1",
+    "httpx-sse==0.4.0",
     "invariant-sdk>=0.0.11",
     "starlette-compress==1.4.0",
     "uvicorn==0.34.0"