invariant-gateway/gateway/routes/mcp_sse.py

"""Gateway service to forward requests to the MCP SSE servers"""
import asyncio
import json
import re
from typing import Tuple
import httpx
from httpx_sse import aconnect_sse, ServerSentEvent
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import (
    CLIENT_TIMEOUT,
    INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
    INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
    MCP_METHOD,
    MCP_TOOL_CALL,
    MCP_LIST_TOOLS,
    MCP_PARAMS,
    MCP_RESULT,
    MCP_SERVER_INFO,
    MCP_CLIENT_INFO,
    UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
    McpSessionsManager,
    SseHeaderAttributes,
)
from gateway.common.mcp_utils import get_mcp_server_base_url
from gateway.mcp.log import format_errors_in_response
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
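
# Allow-lists of client request headers that are forwarded to the upstream MCP
# server, for POST /messages/ calls and for the GET /sse stream respectively.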
MCP_SERVER_POST_HEADERS = {
    "connection",
    "accept",
    "content-length",
    "content-type",
}
MCP_SERVER_SSE_HEADERS = {
    "connection",
    "accept",
    "cache-control",
}
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
gateway = APIRouter()
session_store = McpSessionsManager()
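

# The POST handler below intercepts tool calls and tools/list requests for
# guardrail checks before proxying them to the MCP server. When a call is
# blocked, the JSON-RPC error is queued on the session and delivered to the
# client through the SSE stream served by mcp_get_sse_gateway.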
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
    request: Request,
) -> Response:
    """Proxy calls to the MCP Server tools"""
    query_params = dict(request.query_params)
    print("[MCP POST] Query params:", query_params, flush=True)
    print(
        "[MCP POST] Query params session_id:",
        query_params.get("session_id"),
        flush=True,
    )
    if not query_params.get("session_id"):
        raise HTTPException(
            status_code=400,
            detail="Missing 'session_id' query parameter",
        )
    if not session_store.session_exists(query_params.get("session_id")):
        raise HTTPException(
            status_code=400,
            detail="Session does not exist",
        )
    session_id = query_params.get("session_id")
    mcp_server_messages_endpoint = (
        get_mcp_server_base_url(request) + "/messages/?" + session_id
    )
    request_body_bytes = await request.body()
    request_json = json.loads(request_body_bytes)
    session = session_store.get_session(session_id)
    if request_json.get(MCP_METHOD) and request_json.get("id"):
        session.id_to_method_mapping[request_json.get("id")] = request_json.get(
            MCP_METHOD
        )
    if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
        MCP_CLIENT_INFO
    ):
        session.metadata["mcp_client"] = (
            request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
        )
    if request_json.get(MCP_METHOD) == MCP_TOOL_CALL:
        # Intercept and potentially block the request
        hook_tool_call_result, is_blocked = await _hook_tool_call(
            session_id=session_id, request_json=request_json
        )
        if is_blocked:
            # Add the error message to the session.
            # The error message is sent back to the client using the SSE stream.
            await session.add_pending_error_message(hook_tool_call_result)
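            # Respond 202 as in the normal MCP SSE flow: POSTs are only
            # acknowledged here, and the JSON-RPC payload (in this case the
            # guardrail error) reaches the client over the SSE stream.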
            return Response(content="Accepted", status_code=202)
    elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
        # Intercept and potentially block the request
        hook_tool_call_result, is_blocked = await _hook_tool_call(
            session_id=session_id,
            request_json={
                "id": request_json.get("id"),
                "method": MCP_LIST_TOOLS,
                "params": {"name": MCP_LIST_TOOLS, "arguments": {}},
            },
        )
        if is_blocked:
            # Add the error message to the session.
            # The error message is sent back to the client using the SSE stream.
            await session.add_pending_error_message(hook_tool_call_result)
            return Response(content="Accepted", status_code=202)
    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            response = await client.post(
                url=mcp_server_messages_endpoint,
                headers={
                    k: v
                    for k, v in request.headers.items()
                    if k.lower() in MCP_SERVER_POST_HEADERS
                },
                json=request_json,
                params=query_params,
            )
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers={
                    "X-Proxied-By": "mcp-gateway",
                    **response.headers,
                },
            )
        except httpx.RequestError as e:
            print(f"[MCP POST] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP POST] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e


@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
    request: Request,
) -> StreamingResponse:
    """Proxy the MCP server's SSE event stream to the client."""
    mcp_server_sse_endpoint = get_mcp_server_base_url(request) + "/sse"
    query_params = dict(request.query_params)
    response_headers = {}
    filtered_headers = {
        k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
    }
    sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)

    async def event_generator():
        """
        Generate a merged stream of MCP server events and pending error messages.
        The pending error messages are added in the POST messages handler.
        This function runs in a loop, yielding events as they arrive.
        """
        mcp_server_events_queue = asyncio.Queue()
        pending_error_messages_queue = asyncio.Queue()
        tasks = set()
        session_id = None

        try:
            # MCP Server Events Processor
            async def process_mcp_server_events():
                """Connect to MCP server and process its events."""
                nonlocal session_id
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(CLIENT_TIMEOUT)
                ) as client:
                    try:
                        async with aconnect_sse(
                            client,
                            "GET",
                            mcp_server_sse_endpoint,
                            headers=filtered_headers,
                            params=query_params,
                        ) as event_source:
                            if event_source.response.status_code != 200:
                                error_content = await event_source.response.aread()
                                raise HTTPException(
                                    status_code=event_source.response.status_code,
                                    detail=error_content,
                                )
                            async for sse in event_source.aiter_sse():
                                if sse.event == "endpoint":
                                    (
                                        event_bytes,
                                        extracted_id,
                                    ) = await _handle_endpoint_event(
                                        sse, sse_header_attributes
                                    )
                                    session_id = extracted_id
                                    if (
                                        session_id
                                        and "process_error_messages_task"
                                        not in locals()
                                    ):
                                        process_error_messages_task = (
                                            asyncio.create_task(
                                                _check_for_pending_error_messages(
                                                    session_id,
                                                    pending_error_messages_queue,
                                                )
                                            )
                                        )
                                        tasks.add(process_error_messages_task)
                                        process_error_messages_task.add_done_callback(
                                            tasks.discard
                                        )
                                elif sse.event == "message" and session_id:
                                    # Process message event
                                    event_bytes = await _handle_message_event(
                                        session_id, sse
                                    )
                                else:
                                    # Pass through other event types
                                    # pylint: disable=line-too-long
                                    event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
                                        UTF_8
                                    )
                                # Put the processed event in the queue
                                await mcp_server_events_queue.put(event_bytes)
                    except httpx.StreamClosed as e:
                        print(f"Server stream closed: {e}", flush=True)
                    except Exception as e:  # pylint: disable=broad-except
                        print(f"Error processing server events: {e}", flush=True)

            # Start server events processor
            mcp_server_events_task = asyncio.create_task(process_mcp_server_events())
            tasks.add(mcp_server_events_task)
            mcp_server_events_task.add_done_callback(tasks.discard)
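            # Each iteration waits on both queues and forwards whatever arrives
            # first; the 0.25s timeout keeps the loop from blocking forever when
            # both queues are idle, and cancelling the losing get() task does
            # not consume a queued item.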
            # Main event loop: merge MCP server events and pending error messages
            while True:
                # Create futures for both queues
                mcp_server_event_future = asyncio.create_task(
                    mcp_server_events_queue.get()
                )
                pending_error_message_future = asyncio.create_task(
                    pending_error_messages_queue.get()
                )
                # Wait for either queue to have an item, with timeout
                done, pending = await asyncio.wait(
                    [mcp_server_event_future, pending_error_message_future],
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=0.25,
                )
                for future in pending:
                    future.cancel()
                # Timeout occurred and no future completed.
                if not done:
                    continue
                for future in done:
                    try:
                        event = await future
                        yield event
                    except asyncio.CancelledError:
                        # Future was cancelled, continue
                        continue
        finally:
            # Clean up all tasks
            for task in tasks:
                task.cancel()
            # Wait for all tasks to complete
            if tasks:
                await asyncio.wait(tasks, timeout=2)

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"X-Proxied-By": "mcp-gateway", **response_headers},
    )
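

# The two hooks below translate MCP JSON-RPC tool calls and results into
# chat-style messages (an assistant message with tool_calls, and a tool-role
# message for the result) so the session guardrails can evaluate them. When a
# blocking guardrail fires, they return a JSON-RPC error in place of the
# original payload.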
async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    Args:
        session_id (str): The session ID associated with the request.
        request_json (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: The original request JSON and False if the call is
        allowed, or a JSON-RPC error dict and True if it is blocked.
    """
    tool_call = {
        "id": f"call_{request_json.get('id')}",
        "type": "function",
        "function": {
            "name": request_json.get(MCP_PARAMS).get("name"),
            "arguments": request_json.get(MCP_PARAMS).get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )
    # If the request is blocked, return a message indicating the block reason.
    # If there are new errors, run append_and_push_trace in background.
    # If there are no new errors, just return the original request.
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        # Add the trace to the explorer
        asyncio.create_task(
            session_store.add_message_to_session(
                session_id=session_id,
                message=message,
                guardrails_result=guardrails_result,
            )
        )
        return {
            "jsonrpc": "2.0",
            "id": request_json.get("id"),
            "error": {
                "code": -32600,
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % guardrails_result["errors"],
            },
        }, True
    # Push trace to the explorer - don't block on its response
    await session_store.add_message_to_session(session_id, message, guardrails_result)
    return request_json, False


async def _hook_tool_call_response(
    session_id: str, response_json: dict, is_tools_list=False
) -> Tuple[dict, bool]:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.

    Returns:
        Tuple[dict, bool]: The response JSON and False if no guardrail is
        violated, else an error dict and True.
    """
    blocked = False
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": response_json.get(MCP_RESULT).get("content"),
        "error": response_json.get(MCP_RESULT).get("error"),
    }
    result = response_json
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )
    if (
        guardrails_result
        and guardrails_result.get("errors", [])
        and _check_if_new_errors(session_id, guardrails_result)
    ):
        blocked = True
        # If the request is blocked, return a message indicating the block reason.
        if not is_tools_list:
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "error": {
                    "code": -32600,
                    "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                    % guardrails_result["errors"],
                },
            }
        else:
            # special error response for tools/list tool call
            result = {
                "jsonrpc": "2.0",
                "id": response_json.get("id"),
                "result": {
                    "tools": [
                        {
                            "name": "blocked_" + tool["name"],
                            "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                            % format_errors_in_response(guardrails_result["errors"]),
                            # no parameters
                            "inputSchema": {
                                "properties": {},
                                "required": [],
                                "title": "invariant_mcp_server_blockedArguments",
                                "type": "object",
                            },
                            "annotations": {
                                "title": "This tool was blocked by security guardrails.",
                            },
                        }
                        for tool in response_json["result"]["tools"]
                    ]
                },
            }
    # Push trace to the explorer - don't block on its response
    asyncio.create_task(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    return result, blocked


async def _handle_endpoint_event(
    sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]:
    """
    Handle the endpoint event type and modify the data accordingly.
    For endpoint events, we need to rewrite the endpoint to use our gateway.

    Args:
        sse (ServerSentEvent): The original SSE object.
        sse_header_attributes (SseHeaderAttributes): The header attributes from the request.

    Returns:
        bytes: Modified SSE data as bytes.
        str: session_id extracted from the data.
    """
    # Extract session_id
    match = re.search(r"session_id=([^&\s]+)", sse.data)
    if match:
        session_id = match.group(1)
        # Initialize this session in our store if needed
        if not session_store.session_exists(session_id):
            await session_store.initialize_session(session_id, sse_header_attributes)
    # Rewrite the endpoint to use our gateway
    modified_data = sse.data.replace(
        "/messages/?session_id=",
        "/api/v1/gateway/mcp/sse/messages/?session_id=",
    )
    event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode(UTF_8)
    return event_bytes, session_id
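

# Message events carry JSON-RPC responses from the MCP server. The originating
# request's method is looked up via session.id_to_method_mapping so tool-call
# and tools/list results can be routed through the response hook above.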
async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
    """
    Handle the message event type.

    Args:
        session_id (str): The session ID associated with the request.
        sse (ServerSentEvent): The original SSE object.

    Returns:
        bytes: The SSE event to forward to the client, rewritten if a guardrail
        blocked the tool result.
    """
    event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8)
    session = session_store.get_session(session_id)
    try:
        response_json = json.loads(sse.data)
        if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
            MCP_SERVER_INFO
        ):
            session.metadata["mcp_server"] = (
                response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
            )
        method = session.id_to_method_mapping.get(response_json.get("id"))
        if method == MCP_TOOL_CALL:
            hook_tool_call_response, blocked = await _hook_tool_call_response(
                session_id=session_id,
                response_json=response_json,
            )
            # Update the event bytes with hook_tool_call_response.
            # hook_tool_call_response is same as response_json if no guardrail is violated.
            # If guardrail is violated, it contains the error message.
            # pylint: disable=line-too-long
            if blocked:
                event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
                    UTF_8
                )
        elif method == MCP_LIST_TOOLS:
            # store tools in metadata
            session_store.get_session(session_id).metadata["tools"] = response_json.get(
                MCP_RESULT
            ).get("tools")
            # store tools/list tool call in trace
            hook_tool_call_response, blocked = await _hook_tool_call_response(
                session_id=session_id,
                response_json={
                    "id": response_json.get("id"),
                    "result": {
                        "content": json.dumps(
                            response_json.get(MCP_RESULT).get("tools")
                        ),
                        "tools": response_json.get(MCP_RESULT).get("tools"),
                    },
                },
                is_tools_list=True,
            )
            # Update the event bytes with hook_tool_call_response.
            # hook_tool_call_response is same as response_json if no guardrail is violated.
            # If guardrail is violated, it contains the error message.
            # pylint: disable=line-too-long
            if blocked:
                event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
                    UTF_8
                )
    except json.JSONDecodeError as e:
        print(
            f"[MCP SSE] Error parsing message JSON: {e}",
            flush=True,
        )
    except Exception as e:  # pylint: disable=broad-except
        print(
            f"[MCP SSE] Error processing message: {e}",
            flush=True,
        )
    return event_bytes
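

# Guardrail errors are de-duplicated per session: blocking only fires when the
# guardrails result contains at least one annotation not already recorded on
# the session, so previously reported violations are not re-raised for every
# subsequent message.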
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
    """Checks if there are new errors in the guardrails result."""
    session = session_store.get_session(session_id)
    annotations = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    for annotation in annotations:
        if annotation not in session.annotations:
            return True
    return False
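

# Pending error messages are the JSON-RPC errors queued by the POST handler for
# blocked calls; this poller frames each one as an SSE "message" event so the
# client receives it on the stream it is already listening to.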
async def _check_for_pending_error_messages(
    session_id: str, pending_error_messages_queue: asyncio.Queue
):
    """Periodically check for and enqueue pending error messages."""
    try:
        while True:
            try:
                session = session_store.get_session(session_id)
                error_messages = await session.get_pending_error_messages()
                for error_message in error_messages:
                    error_bytes = (
                        f"event: message\ndata: {json.dumps(error_message)}\n\n".encode(
                            UTF_8
                        )
                    )
                    await pending_error_messages_queue.put(error_bytes)
                await asyncio.sleep(1)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error checking for messages: {e}", flush=True)
                await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return
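

# Note: the endpoint rewrite in _handle_endpoint_event assumes this router is
# mounted under the "/api/v1/gateway" prefix. A minimal wiring sketch (the app
# module shown here is an assumption, not part of this file):
#
#     from fastapi import FastAPI
#     from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
#
#     app = FastAPI()
#     app.include_router(mcp_sse_gateway, prefix="/api/v1/gateway")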