mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-06 05:03:56 +02:00
Add implementation for MCP streamable GET, POST and DELETE endpoints without push to explorer or guardrailing.
This commit is contained in:
@@ -58,6 +58,7 @@ class McpSession(BaseModel):
|
||||
pending_error_messages: List[dict] = Field(default_factory=list)
|
||||
|
||||
# Lock to maintain in-order pushes to explorer
|
||||
# and other session-related operations
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
async def load_guardrails(self) -> None:
|
||||
@@ -292,24 +293,53 @@ class McpSessionsManager:
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, McpSession] = {}
|
||||
# Dictionary to store per-session locks.
|
||||
# Used for session initialization and deletion.
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
# Global lock to protect the locks dictionary itself
|
||||
self._global_lock = asyncio.Lock()
|
||||
|
||||
def session_exists(self, session_id: str) -> bool:
    """
    Report whether a session is registered under the given ID.

    Only membership in the manager's session dictionary is checked; the
    session object itself is not inspected.
    """
    registered_sessions = self._sessions
    return session_id in registered_sessions
|
||||
|
||||
async def _get_session_lock(self, session_id: str) -> asyncio.Lock:
    """
    Return the lock associated with a session ID, creating it on first use.

    The lookup and the insertion both happen while holding the global lock,
    so the locks dictionary is only read or modified by one coroutine at a time.
    """
    async with self._global_lock:
        lock = self._session_locks.get(session_id)
        if lock is None:
            # First request for this session ID: create and remember its lock.
            lock = asyncio.Lock()
            self._session_locks[session_id] = lock
        return lock
|
||||
|
||||
async def cleanup_session_lock(self, session_id: str) -> None:
    """
    Drop the lock stored for a session ID, if one exists.

    The removal is performed while holding the global lock. Unknown session
    IDs are ignored.
    """
    async with self._global_lock:
        # pop() with a default is a no-op when no lock was ever created.
        self._session_locks.pop(session_id, None)
|
||||
|
||||
async def initialize_session(
    self, session_id: str, sse_header_attributes: SseHeaderAttributes
) -> None:
    """
    Initialize a new session unless one is already registered for this ID.

    All work happens while holding the per-session lock returned by
    `_get_session_lock`, so concurrent calls for the same session ID are
    serialized and only the first one creates the session.

    Args:
        session_id: Identifier under which the session is stored.
        sse_header_attributes: Supplies the `explorer_dataset` and
            `push_explorer` values passed to the new `McpSession`.
    """
    # Get the lock for this specific session
    session_lock = await self._get_session_lock(session_id)

    async with session_lock:
        # The existence check must happen under the lock: another coroutine
        # may have created the session while this one was waiting.
        if session_id in self._sessions:
            print(f"Session {session_id} already exists, skipping initialization", flush=True)
            return

        session = McpSession(
            session_id=session_id,
            explorer_dataset=sse_header_attributes.explorer_dataset,
            push_explorer=sse_header_attributes.push_explorer,
        )
        # NOTE(review): the session is registered before its guardrails are
        # loaded, so it is already visible via `session_exists` while
        # `load_guardrails` is awaited - confirm this ordering is intended.
        self._sessions[session_id] = session
        # Load guardrails for the session from the explorer
        await session.load_guardrails()
|
||||
|
||||
def get_session(self, session_id: str) -> McpSession:
|
||||
"""Get a session by ID"""
|
||||
|
||||
@@ -48,4 +48,4 @@ def get_mcp_server_base_url(request: Request) -> str:
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
return _convert_localhost_to_docker_host(mcp_server_base_url)
|
||||
return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/")
|
||||
|
||||
@@ -50,17 +50,11 @@ session_store = McpSessionsManager()
|
||||
|
||||
|
||||
@gateway.post("/mcp/sse/messages/")
|
||||
async def mcp_post_gateway(
|
||||
async def mcp_post_sse_gateway(
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""Proxy calls to the MCP Server tools"""
|
||||
query_params = dict(request.query_params)
|
||||
print("[MCP POST] Query params:", query_params, flush=True)
|
||||
print(
|
||||
"[MCP POST] Query params session_id:",
|
||||
query_params.get("session_id"),
|
||||
flush=True,
|
||||
)
|
||||
if not query_params.get("session_id"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -193,6 +187,9 @@ async def mcp_get_sse_gateway(
|
||||
status_code=event_source.response.status_code,
|
||||
detail=error_content,
|
||||
)
|
||||
response_headers.update(
|
||||
dict(event_source.response.headers.items())
|
||||
)
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.event == "endpoint":
|
||||
|
||||
@@ -1,33 +1,157 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
import json
|
||||
|
||||
from gateway.common.constants import CLIENT_TIMEOUT
|
||||
from gateway.common.mcp_sessions_manager import McpSessionsManager, SseHeaderAttributes
|
||||
from gateway.common.mcp_utils import get_mcp_server_base_url
|
||||
|
||||
from fastapi import APIRouter, Request, Response
|
||||
import httpx
|
||||
|
||||
from httpx_sse import aconnect_sse
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
gateway = APIRouter()
session_store = McpSessionsManager()

# Header carrying the MCP session ID on requests and responses.
MCP_SESSION_ID_HEADER = "mcp-session-id"
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"
# Lower-case names of the client request headers forwarded to the MCP server
# (allow-list; every other request header is dropped).
MCP_SERVER_POST_HEADERS = {
    "connection",
    "accept",
    "content-length",
    "content-type",
    MCP_SESSION_ID_HEADER,
}
# Lower-case names of the client request headers forwarded when opening the
# SSE stream against the MCP server.
MCP_SERVER_SSE_HEADERS = {
    "connection",
    "accept",
    "cache-control",
    MCP_SESSION_ID_HEADER,
}
|
||||
|
||||
|
||||
def get_session_id(request: Request) -> str:
    """
    Extract the session ID from request headers.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        The value of the `MCP_SESSION_ID_HEADER` header.

    Raises:
        HTTPException: With status 400 when the header is missing or empty.
    """
    session_id = request.headers.get(MCP_SESSION_ID_HEADER)
    if not session_id:
        raise HTTPException(
            status_code=400,
            detail=f"Missing {MCP_SESSION_ID_HEADER} header",
        )
    return session_id
|
||||
|
||||
|
||||
def get_mcp_server_endpoint(request: Request) -> str:
    """
    Build the MCP server endpoint URL for a request.

    The base URL is obtained from `get_mcp_server_base_url(request)` and the
    `/mcp/` path is appended to it.
    """
    base_url = get_mcp_server_base_url(request)
    return f"{base_url}/mcp/"
|
||||
|
||||
|
||||
@gateway.post("/mcp/streamable")
async def mcp_post_streamable_gateway(
    request: Request,
) -> Response:
    """
    Forward a POST request to the MCP Streamable server.

    The request body and an allow-listed subset of the request headers are
    sent to the MCP server endpoint. If the server's response carries a
    session ID that is not yet known, a session is registered for it. A JSON
    response is returned as-is; any other response is relayed as a stream of
    server-sent events, keeping the upstream status code.

    Raises:
        HTTPException: 404 when the request carries a session ID that is not
            registered and the request is not an `initialize` call; 500 when
            forwarding the request fails.
    """
    body = await request.body()
    session_id = request.headers.get(MCP_SESSION_ID_HEADER)

    # Determine if this is an initialization request, only for our session tracking.
    # Default to False so that an empty or unparseable body cannot leave the
    # flag unbound when it is read below.
    is_initialization_request = False
    try:
        raw_message = json.loads(body)
    except ValueError:
        # Covers json.JSONDecodeError and undecodable bytes.
        # Let the server handle the validation error
        pass
    else:
        is_initialization_request = (
            isinstance(raw_message, dict)
            and raw_message.get("method") == "initialize"
            and "jsonrpc" in raw_message
        )

    mcp_server_endpoint = get_mcp_server_endpoint(request)
    filtered_headers = {
        k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_POST_HEADERS
    }
    if (
        session_id
        and not is_initialization_request
        and not session_store.session_exists(session_id)
    ):
        raise HTTPException(status_code=404, detail="Invalid or expired session ID")
    sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)

    # Response headers that describe the upstream body framing. The body is
    # re-sent from already decoded content (and re-framed for SSE), so these
    # would no longer match what the gateway actually sends.
    unforwardable_headers = {
        "connection",
        "content-encoding",
        "content-length",
        "transfer-encoding",
    }

    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            response = await client.post(
                url=mcp_server_endpoint,
                headers=filtered_headers,
                content=body,
                follow_redirects=True,
            )

            # If we received a session ID from server, register it in our session store
            # This happens in initialization responses
            resp_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
            if resp_session_id and not session_store.session_exists(resp_session_id):
                await session_store.initialize_session(
                    resp_session_id, sse_header_attributes
                )

            proxied_headers = {"X-Proxied-By": "mcp-gateway"}
            proxied_headers.update(
                (k, v)
                for k, v in response.headers.items()
                if k.lower() not in unforwardable_headers
            )

            # Compare only the media type so that parameters such as
            # "; charset=utf-8" do not push a JSON body down the SSE path.
            media_type = (
                response.headers.get("content-type", "").split(";", 1)[0].strip().lower()
            )

            # If the response is JSON, return it directly
            if media_type == CONTENT_TYPE_JSON:
                return Response(
                    content=response.content,
                    status_code=response.status_code,
                    headers=proxied_headers,
                )

            # Else return SSE streaming response
            async def event_generator():
                # An event is a run of non-empty lines (e.g. "event: ..." and
                # "data: ...") terminated by an empty line. Lines are buffered
                # so each event is sent as one chunk, and every event in the
                # response is relayed, whatever its number of lines.
                pending_lines = []
                async for line in response.aiter_lines():
                    if line.strip():
                        pending_lines.append(line.rstrip("\r\n"))
                    elif pending_lines:
                        yield "\n".join(pending_lines) + "\n\n"
                        pending_lines = []
                if pending_lines:
                    # Flush an event that was not followed by a blank line.
                    yield "\n".join(pending_lines) + "\n\n"

            return StreamingResponse(
                event_generator(),
                # Preserve the upstream status (for example 202 Accepted).
                status_code=response.status_code,
                media_type=CONTENT_TYPE_SSE,
                headers=proxied_headers,
            )

        except httpx.RequestError as e:
            print(f"[MCP POST] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP POST] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
|
||||
@gateway.get("/mcp/streamable")
|
||||
async def mcp_get_gateway(
|
||||
async def mcp_get_streamable_gateway(
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -37,16 +161,91 @@ async def mcp_get_gateway(
|
||||
first sending data via HTTP POST. The server can send JSON-RPC requests
|
||||
and notifications on this stream.
|
||||
"""
|
||||
mcp_server_base_url = get_mcp_server_base_url(request)
|
||||
pass
|
||||
mcp_server_endpoint = get_mcp_server_endpoint(request)
|
||||
response_headers = {}
|
||||
filtered_headers = {
|
||||
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
|
||||
}
|
||||
|
||||
async def event_generator():
|
||||
"""Connect to MCP server and process its events."""
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) as client:
|
||||
try:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
mcp_server_endpoint,
|
||||
headers=filtered_headers,
|
||||
) as event_source:
|
||||
if event_source.response.status_code != 200:
|
||||
error_content = await event_source.response.aread()
|
||||
raise HTTPException(
|
||||
status_code=event_source.response.status_code,
|
||||
detail=error_content,
|
||||
)
|
||||
response_headers.update(dict(event_source.response.headers.items()))
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
yield sse
|
||||
|
||||
except httpx.StreamClosed as e:
|
||||
print(f"Server stream closed: {e}", flush=True)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Error processing server events: {e}", flush=True)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type=CONTENT_TYPE_SSE,
|
||||
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
|
||||
)
|
||||
|
||||
|
||||
@gateway.delete("/mcp/streamable")
async def mcp_delete_streamable_gateway(
    request: Request,
) -> Response:
    """
    Forward a DELETE request to the MCP Streamable server for explicit session termination.

    The session ID is read from the request headers and must belong to a
    registered session. After the server has answered, the lock kept for the
    session is removed and the server's response is returned to the caller.

    Raises:
        HTTPException: 400 when the session ID header is missing or the
            session is not registered; 500 when forwarding the request fails.
    """
    session_id = get_session_id(request)
    if not session_store.session_exists(session_id):
        raise HTTPException(
            status_code=400,
            detail="Session does not exist",
        )
    mcp_server_endpoint = get_mcp_server_endpoint(request)
    # Same allow-list as the POST endpoint.
    filtered_headers = {
        k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_POST_HEADERS
    }

    # Response headers that describe the upstream body framing; the body is
    # re-sent from already decoded content, so they are recomputed instead
    # of being copied.
    unforwardable_headers = {
        "connection",
        "content-encoding",
        "content-length",
        "transfer-encoding",
    }

    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            response = await client.delete(
                url=mcp_server_endpoint,
                headers=filtered_headers,
            )
            # NOTE(review): only the session lock is removed here; the session
            # entry itself is left in the store - confirm whether it should
            # be removed as well.
            await session_store.cleanup_session_lock(session_id)

            proxied_headers = {"X-Proxied-By": "mcp-gateway"}
            proxied_headers.update(
                (k, v)
                for k, v in response.headers.items()
                if k.lower() not in unforwardable_headers
            )
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=proxied_headers,
            )

        except httpx.RequestError as e:
            print(f"[MCP DELETE] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP DELETE] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
Reference in New Issue
Block a user