Files
invariant-gateway/gateway/mcp/streamable.py
2025-06-05 11:58:12 +02:00

434 lines
16 KiB
Python

"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import json
from contextlib import AsyncExitStack
from typing import Any

import httpx
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from httpx_sse import aconnect_sse

from gateway.common.constants import (
    CLIENT_TIMEOUT,
    CONTENT_TYPE_HEADER,
    CONTENT_TYPE_JSON,
    CONTENT_TYPE_EVENT_STREAM,
)
from gateway.mcp.constants import (
    INVARIANT_SESSION_ID_PREFIX,
    MCP_CUSTOM_HEADER_PREFIX,
    UTF_8,
)
from gateway.mcp.mcp_sessions_manager import (
    McpSessionsManager,
    McpAttributes,
)
from gateway.mcp.mcp_transport_base import McpTransportBase
# Router on which the /mcp/streamable POST, GET and DELETE routes are registered.
gateway = APIRouter()
# Module-level session store; the route handlers pass this same instance to
# every StreamableTransport they create.
mcp_sessions_manager = McpSessionsManager()
# Name of the header that carries the MCP session id.
MCP_SESSION_ID_HEADER = "mcp-session-id"
# Client request headers (compared by lowercased name) that are forwarded to the
# MCP server on POST and DELETE requests.
MCP_SERVER_POST_AND_DELETE_HEADERS = {
"connection",
"accept",
CONTENT_TYPE_HEADER,
MCP_SESSION_ID_HEADER,
}
# Client request headers (compared by lowercased name) that are forwarded to the
# MCP server on GET requests.
MCP_SERVER_GET_HEADERS = {
"connection",
"accept",
"cache-control",
MCP_SESSION_ID_HEADER,
}
@gateway.post("/mcp/streamable")
async def mcp_post_streamable_gateway(request: Request):
    """Route handler for POST /mcp/streamable.

    Delegates the incoming request to the streamable transport with the
    method set to "POST" and returns whatever the transport produces.
    """
    # No return annotation on purpose: the signature is kept exactly as it was.
    transport_reply = await create_streamable_transport_and_handle_request(
        request, "POST", mcp_sessions_manager
    )
    return transport_reply
@gateway.get("/mcp/streamable")
async def mcp_get_streamable_gateway(request: Request) -> StreamingResponse:
    """Route handler for GET /mcp/streamable.

    Delegates the incoming request to the streamable transport with the
    method set to "GET" and returns the transport's response.
    """
    transport_reply = await create_streamable_transport_and_handle_request(
        request, "GET", mcp_sessions_manager
    )
    return transport_reply
@gateway.delete("/mcp/streamable")
async def mcp_delete_streamable_gateway(request: Request) -> Response:
    """Route handler for DELETE /mcp/streamable.

    Delegates the incoming request to the streamable transport with the
    method set to "DELETE" and returns the transport's response.
    """
    transport_reply = await create_streamable_transport_and_handle_request(
        request, "DELETE", mcp_sessions_manager
    )
    return transport_reply
async def create_streamable_transport_and_handle_request(
    request: Request, method: str, session_store: McpSessionsManager
) -> Response | StreamingResponse:
    """Build a StreamableTransport for one request and let it handle that request.

    Args:
        request: The incoming client request.
        method: HTTP method name passed through to the transport
            ("POST", "GET" or "DELETE" from the routes in this module).
        session_store: Session manager the new transport is constructed with.

    Returns:
        The response produced by ``StreamableTransport.handle_communication``.
    """
    # A fresh transport object is created per call; only the store is shared.
    transport = StreamableTransport(session_store)
    reply = await transport.handle_communication(request=request, method=method)
    return reply
class StreamableTransport(McpTransportBase):
    """
    Streamable HTTP transport implementation for MCP communication.

    Handles HTTP POST/GET/DELETE requests with JSON and streaming responses:
    POST bodies are inspected (guardrails) before and after being forwarded
    to the MCP server, GET opens a server-to-client SSE stream, and DELETE
    terminates a session.
    """

    # Upstream response headers that describe the framing/encoding of the
    # upstream body. httpx exposes the decoded body and the gateway may
    # replace it (blocked responses), so relaying these verbatim could
    # advertise a wrong length or encoding to the client.
    _NON_RELAYED_RESPONSE_HEADERS = frozenset(
        {
            "connection",
            "content-encoding",
            "content-length",
            "keep-alive",
            "transfer-encoding",
        }
    )

    async def initialize_session(
        self,
        **kwargs,
    ) -> str | None:
        """Initialize streamable HTTP session.

        Keyword Args:
            session_id: Existing session id, if the caller has one.
            session_attributes: Attributes to store when a session is created.
            is_initialization_request: Whether a new id may be generated when
                ``session_id`` is missing.

        Returns:
            The session id (existing, or newly generated), or ``None`` when no
            id was supplied and the request is not an initialization request.
        """
        session_id: str | None = kwargs.get("session_id", None)
        session_attributes: McpAttributes | None = kwargs.get(
            "session_attributes", None
        )
        is_initialization_request: bool = kwargs.get("is_initialization_request", False)

        # Nothing to do for a session the store already knows about.
        if session_id and self.session_store.session_exists(session_id):
            return session_id

        if is_initialization_request and not session_id:
            session_id = self.generate_session_id()

        if (
            session_id
            and not self.session_store.session_exists(session_id)
            and session_attributes
        ):
            await self.session_store.initialize_session(session_id, session_attributes)
        return session_id

    async def handle_post_request(
        self, request: Request, request_body: dict[str, Any]
    ) -> Response | StreamingResponse:
        """Handle POST request to streamable endpoint.

        Args:
            request: Incoming client request.
            request_body: Parsed JSON body of the request.

        Returns:
            The guardrail response when the outgoing request is blocked,
            otherwise the (possibly rewritten) MCP server response.
        """
        session_attributes = McpAttributes.from_request_headers(request.headers)
        session_id = request.headers.get(MCP_SESSION_ID_HEADER)
        is_initialization_request = self._is_initialization_request(request_body)

        # Handle session initialization
        if session_id:
            # NOTE(review): get_session is called even for ids the store has
            # never seen; confirm it copes with unknown session ids.
            self.update_tool_call_id_in_session(
                self.session_store.get_session(session_id), request_body
            )
        elif is_initialization_request:
            session_id = await self.initialize_session(
                session_attributes=session_attributes, is_initialization_request=True
            )

        # Process request if not initialization
        if not is_initialization_request:
            request_interception_result = await self._process_non_init_request(
                session_id, request_body
            )
            if request_interception_result:
                return request_interception_result

        # Forward to MCP server
        return await self._forward_to_mcp_server(
            request,
            request_body,
            session_id,
            session_attributes,
            is_initialization_request,
        )

    async def handle_get_request(self, request: Request) -> StreamingResponse:
        """Handle GET request for server-initiated communication.

        The upstream SSE connection is opened *before* the response object is
        built so that the upstream status code and headers are known: a
        non-200 upstream status is surfaced to the client as an HTTP error
        instead of an empty 200 stream, and upstream headers are relayed.

        Raises:
            HTTPException: With the upstream status when it is not 200, or
                with 500 when the upstream connection cannot be established.
        """
        mcp_server_endpoint = self._get_mcp_server_endpoint(request)
        upstream_headers = self._filter_request_headers(request, MCP_SERVER_GET_HEADERS)

        # The client and the SSE connection must outlive this method (they are
        # consumed by the generator below), so they are owned by an exit stack
        # that the generator closes when the stream ends.
        stack = AsyncExitStack()
        try:
            client = await stack.enter_async_context(
                httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
            )
            event_source = await stack.enter_async_context(
                aconnect_sse(
                    client,
                    "GET",
                    mcp_server_endpoint,
                    headers=upstream_headers,
                )
            )
            upstream_response = event_source.response
            if upstream_response.status_code != 200:
                error_content = await upstream_response.aread()
                raise HTTPException(
                    status_code=upstream_response.status_code,
                    # detail must be JSON-serialisable, so decode the raw bytes.
                    detail=error_content.decode(UTF_8, errors="replace"),
                )
        except httpx.RequestError as e:
            await stack.aclose()
            print(f"[MCP GET] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except BaseException:
            # Includes the HTTPException above: release the connection, then
            # let the error propagate unchanged.
            await stack.aclose()
            raise

        async def event_generator():
            try:
                async for sse in event_source.aiter_sse():
                    # StreamingResponse needs str/bytes chunks, not event objects.
                    yield self._format_sse(sse)
            except httpx.StreamClosed as e:
                print(f"Server stream closed: {e}")
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error processing server events: {e}")
            finally:
                await stack.aclose()

        return StreamingResponse(
            event_generator(),
            media_type=CONTENT_TYPE_EVENT_STREAM,
            headers=self._build_response_headers(upstream_response),
        )

    async def handle_delete_request(self, request: Request) -> Response:
        """Handle DELETE request for session termination.

        Sessions whose id starts with ``INVARIANT_SESSION_ID_PREFIX`` are
        answered locally with an empty 200; all others are forwarded to the
        MCP server and the session lock is cleaned up afterwards.

        Raises:
            HTTPException: 400 when the session header is missing or the
                session is unknown, 500 when the upstream request fails.
        """
        session_id = self._get_session_id(request)
        if not self.session_store.session_exists(session_id):
            raise HTTPException(status_code=400, detail="Session does not exist")

        if session_id.startswith(INVARIANT_SESSION_ID_PREFIX):
            # NOTE(review): no session cleanup happens on this path; confirm
            # that is intended for these sessions.
            return Response(
                content="", status_code=200, headers={"X-Proxied-By": "mcp-gateway"}
            )

        mcp_server_endpoint = self._get_mcp_server_endpoint(request)
        async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
            try:
                response = await client.delete(
                    url=mcp_server_endpoint,
                    headers=self._get_headers_for_mcp_post_and_delete(request),
                )
            except httpx.RequestError as e:
                print(f"[MCP DELETE] Request error: {str(e)}")
                raise HTTPException(status_code=500, detail="Request error") from e

        await self.session_store.cleanup_session_lock(session_id)
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers=self._build_response_headers(response),
        )

    async def handle_communication(self, **kwargs) -> Response | StreamingResponse:
        """Main communication handler for streamable transport.

        Keyword Args:
            request: The incoming client request.
            method: "POST" (default), "GET" or "DELETE".

        Raises:
            HTTPException: 400 when a POST body is not a JSON object, 405 for
                any other method.
        """
        request = kwargs.get("request")
        method = kwargs.get("method", "POST")

        if method == "POST":
            try:
                request_body = json.loads(await request.body())
            except ValueError as e:
                # Covers json.JSONDecodeError and undecodable byte sequences.
                raise HTTPException(
                    status_code=400, detail="Request body is not valid JSON"
                ) from e
            # The handlers below call dict methods on the body, so anything
            # else would otherwise surface as an internal server error.
            if not isinstance(request_body, dict):
                raise HTTPException(
                    status_code=400, detail="Request body must be a JSON object"
                )
            return await self.handle_post_request(request, request_body)
        if method == "GET":
            return await self.handle_get_request(request)
        if method == "DELETE":
            return await self.handle_delete_request(request)
        raise HTTPException(status_code=405, detail="Method not allowed")

    async def _process_non_init_request(
        self, session_id: str, request_body: dict[str, Any]
    ) -> Response | None:
        """Process non-initialization requests for guardrails.

        Returns:
            A 400 JSON response carrying the processed request when it is
            blocked, otherwise ``None``.
        """
        processed_request, is_blocked = await self.process_outgoing_request(
            session_id, request_body
        )
        if is_blocked:
            return Response(
                content=json.dumps(processed_request),
                status_code=400,
                media_type=CONTENT_TYPE_JSON,
            )
        return None

    async def _forward_to_mcp_server(
        self,
        request: Request,
        request_body: dict[str, Any],
        session_id: str,
        session_attributes: McpAttributes,
        is_initialization_request: bool,
    ) -> Response | StreamingResponse:
        """Forward request to MCP server and handle response.

        The upstream response is read completely by ``client.post`` before it
        is relayed, so it can safely be used after the client is closed.

        Raises:
            HTTPException: 500 when the upstream request fails.
        """
        async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
            try:
                response = await client.post(
                    url=self._get_mcp_server_endpoint(request),
                    headers=self._get_headers_for_mcp_post_and_delete(request),
                    content=json.dumps(request_body).encode(),
                    follow_redirects=True,
                )
            except httpx.RequestError as e:
                print(f"[MCP POST] Request error: {str(e)}")
                raise HTTPException(status_code=500, detail="Request error") from e

        # Handle session ID from MCP server response
        resp_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
        if resp_session_id:
            if not self.session_store.session_exists(resp_session_id):
                await self.session_store.initialize_session(
                    resp_session_id, session_attributes
                )
            session_id = resp_session_id
        elif is_initialization_request and not self.session_store.session_exists(
            session_id
        ):
            await self.session_store.initialize_session(session_id, session_attributes)

        # Update client info for initialization requests
        if is_initialization_request:
            self.update_mcp_client_info_in_session(
                self.session_store.get_session(session_id), request_body
            )

        # Handle response based on content type
        if self._is_json_response(response):
            return await self._handle_json_response(
                session_id, is_initialization_request, response
            )
        return await self._handle_streaming_response(
            session_id, is_initialization_request, response
        )

    async def _handle_json_response(
        self, session_id: str, is_initialization_request: bool, response: httpx.Response
    ) -> Response:
        """Handle JSON response from MCP server.

        Non-initialization responses are run through the incoming guardrails;
        a blocked response is replaced by the processed payload with status 400.
        """
        response_content = response.content
        response_body = (
            json.loads(response_content.decode(UTF_8)) if response_content else {}
        )
        if response_body:
            self._update_mcp_response_info_in_session(session_id, response_body, True)

        response_code = response.status_code
        if not is_initialization_request and response_body:
            processed_response, blocked = await self.process_incoming_response(
                session_id, response_body
            )
            if blocked:
                response_content = json.dumps(processed_response).encode(UTF_8)
                response_code = 400

        return Response(
            content=response_content,
            status_code=response_code,
            headers=self._build_response_headers(response, session_id),
        )

    async def _handle_streaming_response(
        self, session_id: str, is_initialization_request: bool, response: httpx.Response
    ) -> StreamingResponse:
        """Handle streaming response from MCP server.

        The upstream body is parsed as a sequence of SSE events (field lines
        terminated by a blank line). Every event is relayed; the stream stops
        early only after a blocked event has been replaced by the guardrail
        payload. The upstream status code is preserved.
        """

        async def event_generator():
            event_lines: list[str] = []
            async for line in response.aiter_lines():
                stripped_line = line.strip()
                if stripped_line:
                    event_lines.append(stripped_line)
                    continue
                # A blank line terminates the current event (if any).
                if not event_lines:
                    continue
                chunk, blocked = await self._process_sse_event(
                    session_id, is_initialization_request, event_lines
                )
                event_lines = []
                yield chunk
                if blocked:
                    return
            # Relay a final event that was not followed by a blank line.
            if event_lines:
                chunk, _ = await self._process_sse_event(
                    session_id, is_initialization_request, event_lines
                )
                yield chunk

        return StreamingResponse(
            event_generator(),
            status_code=response.status_code,
            media_type=CONTENT_TYPE_EVENT_STREAM,
            headers=self._build_response_headers(response, session_id),
        )

    async def _process_sse_event(
        self,
        session_id: str,
        is_initialization_request: bool,
        event_lines: list[str],
    ) -> tuple[str, bool]:
        """Inspect one SSE event and return ``(text_to_relay, blocked)``.

        Events without a ``data`` field are relayed untouched. Otherwise the
        data is parsed as JSON and either recorded in the session
        (initialization) or run through the incoming guardrails.
        """
        passthrough = "\n".join(event_lines) + "\n\n"
        data_lines = [entry for entry in event_lines if entry.startswith("data:")]
        if not data_lines:
            return passthrough, False

        # Strip only the field name so payloads that themselves contain the
        # text "data: " are parsed intact.
        payload = "\n".join(entry[len("data:"):].strip() for entry in data_lines)
        response_body = json.loads(payload)

        if is_initialization_request:
            self._update_mcp_response_info_in_session(session_id, response_body, False)
            return passthrough, False

        processed_response, blocked = await self.process_incoming_response(
            session_id, response_body
        )
        if not blocked:
            return passthrough, False

        # Keep the non-data fields (e.g. "event: ...") and swap the payload.
        other_lines = [entry for entry in event_lines if not entry.startswith("data:")]
        replaced = [*other_lines, f"data: {json.dumps(processed_response)}"]
        return "\n".join(replaced) + "\n\n", True

    @staticmethod
    def _format_sse(sse: Any) -> str:
        """Serialise a parsed server-sent event back to SSE wire format.

        Reads the ``event``, ``data``, ``id`` and ``retry`` attributes of the
        event object yielded by ``aiter_sse``.
        """
        lines = []
        if sse.event:
            lines.append(f"event: {sse.event}")
        # Multi-line data is sent as one "data:" field per line.
        lines.extend(f"data: {part}" for part in sse.data.split("\n"))
        if sse.id:
            lines.append(f"id: {sse.id}")
        if sse.retry is not None:
            lines.append(f"retry: {sse.retry}")
        return "\n".join(lines) + "\n\n"

    def _is_json_response(self, response: httpx.Response) -> bool:
        """Return True when the upstream Content-Type denotes a JSON body.

        Parameters such as ``; charset=utf-8`` and letter case are ignored.
        """
        content_type = response.headers.get(CONTENT_TYPE_HEADER, "")
        if content_type == CONTENT_TYPE_JSON:
            return True
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == CONTENT_TYPE_JSON.lower()

    def _build_response_headers(
        self, response: httpx.Response, session_id: str | None = None
    ) -> dict[str, str]:
        """Build the headers relayed to the client for an upstream response.

        Adds the ``X-Proxied-By`` marker, copies upstream headers except the
        body-framing ones, and adds the session id header when the upstream
        response did not set it and a session id is known.
        """
        response_headers = {"X-Proxied-By": "mcp-gateway"}
        for name, value in response.headers.items():
            if name.lower() not in self._NON_RELAYED_RESPONSE_HEADERS:
                response_headers[name] = value
        if session_id and MCP_SESSION_ID_HEADER not in response.headers:
            response_headers[MCP_SESSION_ID_HEADER] = session_id
        return response_headers

    def _update_mcp_response_info_in_session(
        self, session_id: str, response_body: dict, is_json_response: bool
    ) -> None:
        """Update MCP response info in session metadata.

        Records whether the session id is gateway-prefixed
        (``is_stateless_http_server``) and whether the server answered with
        JSON or SSE (``server_response_type``).
        """
        session = self.session_store.get_session(session_id)
        self.update_mcp_server_in_session_metadata(session, response_body)
        session.attributes.metadata["is_stateless_http_server"] = session_id.startswith(
            INVARIANT_SESSION_ID_PREFIX
        )
        session.attributes.metadata["server_response_type"] = (
            "json" if is_json_response else "sse"
        )

    def _filter_request_headers(
        self, request: Request, allowed_headers: set[str]
    ) -> dict[str, str]:
        """Select the client headers that are forwarded to the MCP server.

        Headers starting with ``MCP_CUSTOM_HEADER_PREFIX`` are forwarded with
        the prefix removed. Headers whose lowercased name is in
        ``allowed_headers`` are forwarded as-is, except a session id header
        whose value starts with ``INVARIANT_SESSION_ID_PREFIX``.
        """
        filtered_headers = {}
        for k, v in request.headers.items():
            if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
                filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
            lowered = k.lower()
            if lowered not in allowed_headers:
                continue
            if lowered == MCP_SESSION_ID_HEADER and v.startswith(
                INVARIANT_SESSION_ID_PREFIX
            ):
                continue
            filtered_headers[k] = v
        return filtered_headers

    def _get_headers_for_mcp_post_and_delete(self, request: Request) -> dict:
        """Get filtered headers for MCP server requests."""
        return self._filter_request_headers(request, MCP_SERVER_POST_AND_DELETE_HEADERS)

    def _get_session_id(self, request: Request) -> str:
        """Extract session ID from request headers.

        Raises:
            HTTPException: 400 when the session id header is missing or empty.
        """
        session_id = request.headers.get(MCP_SESSION_ID_HEADER)
        if not session_id:
            raise HTTPException(status_code=400, detail="Missing mcp-session-id header")
        return session_id

    def _get_mcp_server_endpoint(self, request: Request) -> str:
        """Get MCP server endpoint URL."""
        # Tolerate a base URL that already ends with "/" (avoids "//mcp/").
        return self.get_mcp_server_base_url(request).rstrip("/") + "/mcp/"

    def _is_initialization_request(self, request_data: dict[str, Any]) -> bool:
        """Check if request is an initialization request."""
        return (
            request_data.get("method") in ("initialize", "notifications/initialized")
            and "jsonrpc" in request_data
        )