Add implementation for MCP streamable GET, POST and DELETE endpoints without push to explorer or guardrailing.

This commit is contained in:
Hemang
2025-05-22 15:02:56 +02:00
committed by Hemang Sarkar
parent f8bf7be405
commit 71e2ac9a06
4 changed files with 257 additions and 31 deletions
+39 -9
View File
@@ -58,6 +58,7 @@ class McpSession(BaseModel):
pending_error_messages: List[dict] = Field(default_factory=list)
# Lock to maintain in-order pushes to explorer
# and other session-related operations
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def load_guardrails(self) -> None:
@@ -292,24 +293,53 @@ class McpSessionsManager:
def __init__(self):
self._sessions: dict[str, McpSession] = {}
# Dictionary to store per-session locks.
# Used for session initialization and deletion.
self._session_locks: dict[str, asyncio.Lock] = {}
# Global lock to protect the locks dictionary itself
self._global_lock = asyncio.Lock()
def session_exists(self, session_id: str) -> bool:
"""Check if a session exists"""
return session_id in self._sessions
async def _get_session_lock(self, session_id: str) -> asyncio.Lock:
"""
Get a lock for a specific session ID, creating one if it doesn't exist.
Uses the global lock to protect access to the locks dictionary.
"""
async with self._global_lock:
if session_id not in self._session_locks:
self._session_locks[session_id] = asyncio.Lock()
return self._session_locks[session_id]
async def cleanup_session_lock(self, session_id: str) -> None:
"""Remove a session lock when it's no longer needed"""
async with self._global_lock:
if session_id in self._session_locks:
del self._session_locks[session_id]
async def initialize_session(
self, session_id: str, sse_header_attributes: SseHeaderAttributes
) -> None:
"""Initialize a new session"""
if session_id not in self._sessions:
session = McpSession(
session_id=session_id,
explorer_dataset=sse_header_attributes.explorer_dataset,
push_explorer=sse_header_attributes.push_explorer,
)
self._sessions[session_id] = session
# Load guardrails for the session from the explorer
await session.load_guardrails()
# Get the lock for this specific session
session_lock = await self._get_session_lock(session_id)
# Acquire the lock for this session
async with session_lock:
# Check again if session exists (it might have been created while waiting for the lock)
if session_id not in self._sessions:
session = McpSession(
session_id=session_id,
explorer_dataset=sse_header_attributes.explorer_dataset,
push_explorer=sse_header_attributes.push_explorer,
)
self._sessions[session_id] = session
# Load guardrails for the session from the explorer
await session.load_guardrails()
else:
print(f"Session {session_id} already exists, skipping initialization", flush=True)
def get_session(self, session_id: str) -> McpSession:
"""Get a session by ID"""
+1 -1
View File
@@ -48,4 +48,4 @@ def get_mcp_server_base_url(request: Request) -> str:
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return _convert_localhost_to_docker_host(mcp_server_base_url)
return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/")
+4 -7
View File
@@ -50,17 +50,11 @@ session_store = McpSessionsManager()
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
async def mcp_post_sse_gateway(
request: Request,
) -> Response:
"""Proxy calls to the MCP Server tools"""
query_params = dict(request.query_params)
print("[MCP POST] Query params:", query_params, flush=True)
print(
"[MCP POST] Query params session_id:",
query_params.get("session_id"),
flush=True,
)
if not query_params.get("session_id"):
raise HTTPException(
status_code=400,
@@ -193,6 +187,9 @@ async def mcp_get_sse_gateway(
status_code=event_source.response.status_code,
detail=error_content,
)
response_headers.update(
dict(event_source.response.headers.items())
)
async for sse in event_source.aiter_sse():
if sse.event == "endpoint":
+213 -14
View File
@@ -1,33 +1,157 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import json
from gateway.common.constants import CLIENT_TIMEOUT
from gateway.common.mcp_sessions_manager import McpSessionsManager, SseHeaderAttributes
from gateway.common.mcp_utils import get_mcp_server_base_url
from fastapi import APIRouter, Request, Response
import httpx
from httpx_sse import aconnect_sse
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
MCP_SESSION_ID_HEADER = "mcp-session-id"
gateway = APIRouter()
session_store = McpSessionsManager()
MCP_SESSION_ID_HEADER = "mcp-session-id"
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"
MCP_SERVER_POST_HEADERS = {
"connection",
"accept",
"content-length",
"content-type",
MCP_SESSION_ID_HEADER,
}
MCP_SERVER_SSE_HEADERS = {
"connection",
"accept",
"cache-control",
MCP_SESSION_ID_HEADER,
}
def get_session_id(request: Request) -> str | None:
def get_session_id(request: Request) -> str:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
if not session_id:
raise HTTPException(
status_code=400,
detail=f"Missing {MCP_SESSION_ID_HEADER} header",
)
return session_id
def get_mcp_server_endpoint(request: Request) -> str:
"""
Extract the MCP server endpoint from the request headers.
"""
return get_mcp_server_base_url(request) + "/mcp/"
@gateway.post("/mcp/streamable")
async def mcp_post_gateway(
async def mcp_post_streamable_gateway(
request: Request,
) -> Response:
) -> StreamingResponse:
"""
Forward a POST request to the MCP Streamable server.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
body = await request.body()
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
# Determine if this is an initialization request, only for our session tracking
try:
raw_message = json.loads(body)
is_initialization_request = (
isinstance(raw_message, dict)
and raw_message.get("method") == "initialize"
and "jsonrpc" in raw_message
)
except json.JSONDecodeError:
# Let the server handle the validation error
pass
mcp_server_endpoint = get_mcp_server_endpoint(request)
filtered_headers = {
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_POST_HEADERS
}
if (
session_id
and not is_initialization_request
and not session_store.session_exists(session_id)
):
raise HTTPException(status_code=404, detail="Invalid or expired session ID")
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
response = await client.post(
url=mcp_server_endpoint,
headers=filtered_headers,
content=body,
follow_redirects=True,
)
# If we received a session ID from server, register it in our session store
# This happens in initialization responses
resp_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
if resp_session_id and not session_store.session_exists(resp_session_id):
await session_store.initialize_session(
resp_session_id, sse_header_attributes
)
# If the response is JSON, return it directly
if response.headers.get("content-type", "") == CONTENT_TYPE_JSON:
return Response(
content=response.content,
status_code=response.status_code,
headers={"X-Proxied-By": "mcp-gateway", **response.headers},
)
# Else return SSE streaming response
async def event_generator():
# Events have two parts:
# 1. event: {type} -> contains the type of event
# 2. data: {data} -> contains the actual message
# We are reading line by line so we need to buffer so that we can
# send the entire event (with both type and data) together.
# Once we receive an empty line, we end the stream.
buffer = ""
async for line in response.aiter_lines():
if line.strip():
if buffer:
complete_event = buffer + "\n" + line + "\n\n"
yield complete_event
# Clear the buffer for the next event
buffer = ""
else:
buffer = line
else:
# End stream here when line is empty.
break
return StreamingResponse(
event_generator(),
media_type=CONTENT_TYPE_SSE,
headers={
"X-Proxied-By": "mcp-gateway",
**response.headers,
},
)
except httpx.RequestError as e:
print(f"[MCP POST] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
except Exception as e:
print(f"[MCP POST] Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Unexpected error") from e
@gateway.get("/mcp/streamable")
async def mcp_get_gateway(
async def mcp_get_streamable_gateway(
request: Request,
) -> Response:
"""
@@ -37,16 +161,91 @@ async def mcp_get_gateway(
first sending data via HTTP POST. The server can send JSON-RPC requests
and notifications on this stream.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
mcp_server_endpoint = get_mcp_server_endpoint(request)
response_headers = {}
filtered_headers = {
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
}
async def event_generator():
"""Connect to MCP server and process its events."""
async with httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) as client:
try:
async with aconnect_sse(
client,
"GET",
mcp_server_endpoint,
headers=filtered_headers,
) as event_source:
if event_source.response.status_code != 200:
error_content = await event_source.response.aread()
raise HTTPException(
status_code=event_source.response.status_code,
detail=error_content,
)
response_headers.update(dict(event_source.response.headers.items()))
async for sse in event_source.aiter_sse():
yield sse
except httpx.StreamClosed as e:
print(f"Server stream closed: {e}", flush=True)
except Exception as e: # pylint: disable=broad-except
print(f"Error processing server events: {e}", flush=True)
return StreamingResponse(
event_generator(),
media_type=CONTENT_TYPE_SSE,
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
)
@gateway.delete("/mcp/streamable")
async def mcp_delete_gateway(
async def mcp_delete_streamable_gateway(
request: Request,
) -> Response:
"""
Forward a DELETE request to the MCP Streamable server for explicit session termination.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
session_id = get_session_id(request)
if not session_store.session_exists(session_id):
raise HTTPException(
status_code=400,
detail="Session does not exist",
)
mcp_server_endpoint = get_mcp_server_endpoint(request)
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
response = await client.delete(
url=mcp_server_endpoint,
headers={
k: v
for k, v in request.headers.items()
if k.lower()
in {
"connection",
"accept",
"content-length",
"content-type",
MCP_SESSION_ID_HEADER,
}
},
)
await session_store.cleanup_session_lock(session_id)
return Response(
content=response.content,
status_code=response.status_code,
headers={
"X-Proxied-By": "mcp-gateway",
**response.headers,
},
)
except httpx.RequestError as e:
print(f"[MCP POST] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
except Exception as e:
print(f"[MCP POST] Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Unexpected error") from e