mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-31 18:39:33 +02:00
Add MCP SSE server proxying in gateway.
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
"""MCP Sessions Manager related classes"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
class McpSession(BaseModel):
    """
    State kept by the gateway for a single MCP session.

    Stores the captured messages, free-form metadata, the request id to
    method name mapping, and the bookkeeping used to push the conversation
    to the explorer as one trace.
    """

    session_id: str
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    id_to_method_mapping: Dict[int, str] = Field(default_factory=dict)
    explorer_dataset: str
    push_explorer: bool
    trace_id: Optional[str] = None
    last_trace_length: int = 0

    # Serializes explorer pushes so messages arrive in order
    _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)

    @contextlib.asynccontextmanager
    async def session_lock(self):
        """
        Async context manager that holds the session lock for its body.

        Usage:
        async with session.session_lock():
        # Code that requires exclusive access to the session
        """
        async with self._lock:
            yield

    async def add_message(self, message: Dict[str, Any]) -> None:
        """
        Append a message to the session; push it to explorer when enabled.

        Args:
            message: The message to add
        """
        async with self.session_lock():
            # pylint: disable=no-member
            self.messages.append(message)
            if not self.push_explorer:
                return
            # Pushing while the lock is held keeps explorer updates ordered
            await self._push_trace_update()

    async def _push_trace_update(self) -> None:
        """
        Sync the session messages to the explorer.

        The first push creates a trace holding every message so far; later
        pushes append only the messages added since the previous push.
        Failures are printed and never propagated.

        This is an internal method that should only be called within a lock.
        """
        try:
            client = AsyncClient(
                api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
            )

            if self.trace_id:
                # A trace already exists: send only what it does not have yet
                pending = self.messages[self.last_trace_length :]
                if pending:
                    request = AppendMessagesRequest(
                        trace_id=self.trace_id,
                        messages=pending,
                    )
                    await client.append_messages(request)
            else:
                # pylint: disable=no-member
                trace_metadata = {
                    "source": "mcp",
                    "tools": self.metadata.get("tools", []),
                }
                # Client/server names are attached only once they are known
                for target_key, source_key in (
                    ("mcp_client", "mcp_client_name"),
                    ("mcp_server", "mcp_server_name"),
                ):
                    if self.metadata.get(source_key):
                        trace_metadata[target_key] = self.metadata.get(source_key)

                request = PushTracesRequest(
                    messages=[self.messages],
                    dataset=self.explorer_dataset,
                    metadata=[trace_metadata],
                )
                response = await client.push_trace(request)
                self.trace_id = response.id[0]
            self.last_trace_length = len(self.messages)
        except Exception as e:  # pylint: disable=broad-except
            print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
|
||||
|
||||
|
||||
class SseHeaderAttributes(BaseModel):
    """
    Explorer settings derived from the headers of an incoming request.
    """

    push_explorer: bool
    explorer_dataset: str

    @classmethod
    def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
        """
        Build an instance from the request headers.

        Reads the PROJECT-NAME and PUSH-EXPLORER headers.

        Args:
            headers: Request headers

        Returns:
            SseHeaderAttributes: An instance with values extracted from headers
        """
        project_name = headers.get("PROJECT-NAME")
        # Pushing is enabled only by the literal value "true" (case-insensitive)
        wants_push = headers.get("PUSH-EXPLORER", "false").lower() == "true"
        # Without a project name, fall back to a randomly numbered dataset name
        dataset = project_name or f"mcp-capture-{random.randint(1, 100)}"
        return cls(push_explorer=wants_push, explorer_dataset=dataset)
|
||||
|
||||
|
||||
class McpSessionsManager:
    """
    In-memory registry of MCP sessions keyed by session ID.
    """

    def __init__(self):
        self._sessions: dict[str, McpSession] = {}

    def session_exists(self, session_id: str) -> bool:
        """Return True when a session with this ID is registered"""
        return session_id in self._sessions

    def initialize_session(
        self, session_id: str, sse_header_attributes: SseHeaderAttributes
    ) -> None:
        """Register a new session; an already registered ID is left untouched"""
        if session_id in self._sessions:
            return
        self._sessions[session_id] = McpSession(
            session_id=session_id,
            explorer_dataset=sse_header_attributes.explorer_dataset,
            push_explorer=sse_header_attributes.push_explorer,
        )

    def get_session(self, session_id: str) -> McpSession:
        """Return the session for this ID, raising ValueError when unknown"""
        if session_id in self._sessions:
            return self._sessions[session_id]
        raise ValueError(f"Session {session_id} does not exist.")

    async def add_message_to_session(
        self, session_id: str, message: Dict[str, Any]
    ) -> None:
        """
        Append a message to a session, pushing to explorer if enabled.

        Args:
            session_id: The session ID
            message: The message to add
        """
        target = self.get_session(session_id)
        await target.add_message(message)
|
||||
@@ -0,0 +1,332 @@
|
||||
"""Gateway service to forward requests to the MCP SSE servers"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
from httpx_sse import aconnect_sse, ServerSentEvent
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
|
||||
|
||||
# Keys and method names looked up in the JSON payloads exchanged with the
# MCP server.
MCP_METHOD = "method"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"
MCP_PARAMS = "params"
MCP_RESULT = "result"
MCP_SERVER_INFO = "serverInfo"
MCP_CLIENT_INFO = "clientInfo"
# Lower-case names of the incoming request headers that are forwarded to the
# MCP server on POST; every other header is dropped.
MCP_SERVER_POST_HEADERS = {
    "connection",
    "accept",
    "content-length",
    "content-type",
}
# Lower-case names of the incoming request headers that are forwarded to the
# MCP server when opening the SSE stream.
MCP_SERVER_SSE_HEADERS = {
    "connection",
    "accept",
    "cache-control",
}

# Router holding the MCP SSE routes defined in this module.
gateway = APIRouter()
# Module-level session registry shared by the routes and helpers below.
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
    request: Request,
) -> Response:
    """
    Proxy a client POST to the MCP server's messages endpoint.

    The JSON payload is inspected before forwarding: the request id is
    mapped to its method, the client name is stored in the session
    metadata, and tool calls are recorded on the session.

    Args:
        request: Incoming request. Requires a 'session_id' query parameter
            of a known session and a 'mcp-server-base-url' header.

    Returns:
        Response: The MCP server's response body and status code.

    Raises:
        HTTPException: 400 when the session id, the session, or the base
            URL header is missing, or when the body is not valid JSON;
            500 when forwarding to the MCP server fails.
    """
    query_params = dict(request.query_params)
    session_id = query_params.get("session_id")
    # HTTPException must be raised; returning it does not produce an error
    # response.
    if not session_id:
        raise HTTPException(
            status_code=400,
            detail="Missing 'session_id' query parameter",
        )
    if not session_store.session_exists(session_id):
        raise HTTPException(
            status_code=400,
            detail="Session does not exist",
        )
    mcp_server_base_url = request.headers.get("mcp-server-base-url")
    if not mcp_server_base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )

    # The session_id is sent through `params` below, so the URL itself
    # carries no query string.
    mcp_server_messages_endpoint = (
        _convert_localhost_to_docker_host(mcp_server_base_url) + "/messages/"
    )

    request_body_bytes = await request.body()
    try:
        request_json = json.loads(request_body_bytes)
    except ValueError as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses
        raise HTTPException(
            status_code=400,
            detail="Request body is not valid JSON",
        ) from e

    session = session_store.get_session(session_id)
    # Only JSON objects are inspected; any other payload is forwarded as-is.
    if isinstance(request_json, dict):
        method = request_json.get(MCP_METHOD)
        request_id = request_json.get("id")
        # 0 is a legal JSON-RPC id, so test for presence rather than truthiness
        if method and request_id is not None:
            session.id_to_method_mapping[request_id] = method

        params = request_json.get(MCP_PARAMS)
        client_info = params.get(MCP_CLIENT_INFO) if isinstance(params, dict) else None
        if client_info:
            session.metadata["mcp_client_name"] = client_info.get("name", "")

        # A tool call without params cannot be recorded; it is still
        # forwarded so the MCP server can reject it.
        if method == MCP_TOOL_CALL and isinstance(params, dict):
            _hook_tool_call(session_id=session_id, request_json=request_json)

    # Headers that describe the upstream wire encoding; the body handed to
    # Response below is already decoded, so Starlette must recompute them.
    dropped_response_headers = {
        "connection",
        "content-encoding",
        "content-length",
        "transfer-encoding",
    }

    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            response = await client.post(
                url=mcp_server_messages_endpoint,
                headers={
                    k: v
                    for k, v in request.headers.items()
                    if k.lower() in MCP_SERVER_POST_HEADERS
                },
                # Forward the original bytes: the client's content-length
                # header is forwarded too, and re-serializing the JSON could
                # change the body length.
                content=request_body_bytes,
                params=query_params,
            )
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers={
                    "X-Proxied-By": "mcp-gateway",
                    **{
                        k: v
                        for k, v in response.headers.items()
                        if k.lower() not in dropped_response_headers
                    },
                },
            )

        except httpx.RequestError as e:
            print(f"[MCP POST] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP POST] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
|
||||
@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
    request: Request,
) -> StreamingResponse:
    """
    Open the MCP server's SSE stream and relay its events to the client.

    "endpoint" events are rewritten by _handle_endpoint_event, which also
    yields the session id; "message" events seen after that are passed
    through _handle_message_event. All other events are relayed as
    "event:"/"data:" lines.

    Raises:
        HTTPException: 400 when the 'mcp-server-base-url' header is missing.
    """
    base_url = request.headers.get("mcp-server-base-url")
    if not base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )
    mcp_server_sse_endpoint = _convert_localhost_to_docker_host(base_url) + "/sse"
    query_params = dict(request.query_params)

    async def event_generator():
        # Only the allow-listed request headers reach the MCP server
        forwarded_headers = {
            name: value
            for name, value in request.headers.items()
            if name.lower() in MCP_SERVER_SSE_HEADERS
        }
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(CLIENT_TIMEOUT),
            headers=forwarded_headers,
        ) as client:
            try:
                async with aconnect_sse(
                    client,
                    "GET",
                    mcp_server_sse_endpoint,
                    params=query_params,
                ) as event_source:
                    upstream = event_source.response
                    if upstream.status_code != 200:
                        error_content = await upstream.aread()
                        # Raised inside the generator, so it is caught and
                        # printed by the handlers below.
                        raise HTTPException(
                            status_code=upstream.status_code,
                            detail=error_content,
                        )

                    # Unknown until the first "endpoint" event arrives
                    session_id = None

                    async for sse in event_source.aiter_sse():
                        if sse.event == "endpoint":
                            header_attributes = SseHeaderAttributes.from_request_headers(
                                request.headers
                            )
                            event_bytes, session_id = _handle_endpoint_event(
                                sse,
                                sse_header_attributes=header_attributes,
                            )
                        elif sse.event == "message" and session_id:
                            event_bytes = _handle_message_event(
                                session_id=session_id, sse=sse
                            )
                        else:
                            # Relay everything else as event/data lines
                            event_bytes = (
                                f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
                            )
                        yield event_bytes

            except httpx.StreamClosed as e:
                print(f"[MCP SSE] Stream closed: {str(e)}", flush=True)
            except httpx.RequestError as e:
                print(f"[MCP SSE] Request error: {str(e)}", flush=True)
            except Exception as e:  # pylint: disable=broad-except
                print(f"[MCP SSE] Unexpected error: {str(e)}", flush=True)

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"X-Proxied-By": "mcp-gateway"},
    )
|
||||
|
||||
|
||||
# Strong references to in-flight background pushes. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_PENDING_TRACE_TASKS: set = set()


def _schedule_session_message(session_id: str, message: dict) -> None:
    """
    Add the message to the session in a background task without awaiting it.

    Args:
        session_id (str): The session ID the message belongs to.
        message (dict): The message to add to the session.
    """
    task = asyncio.create_task(
        session_store.add_message_to_session(session_id, message)
    )
    _PENDING_TRACE_TASKS.add(task)
    # Drop the reference once the task has completed
    task.add_done_callback(_PENDING_TRACE_TASKS.discard)


def _hook_tool_call(session_id: str, request_json: dict) -> None:
    """
    Hook to process the request JSON before sending it to the MCP server.

    Records the tool call as an assistant message on the session.

    Args:
        session_id (str): The session ID associated with the request.
        request_json (dict): The request JSON to be processed.
    """
    # A request without params yields a tool call with no name/arguments
    # instead of raising AttributeError.
    params = request_json.get(MCP_PARAMS) or {}
    tool_call = {
        "id": f"call_{request_json.get('id')}",
        "type": "function",
        "function": {
            "name": params.get("name"),
            "arguments": params.get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
    # Push trace to the explorer - don't block on its response
    _schedule_session_message(session_id, message)


def _hook_tool_call_response(session_id: str, response_json: dict) -> None:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    Records the tool result as a tool message on the session.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.
    """
    # Error responses carry no "result"; fall back to the top-level "error"
    # so the failure is still recorded.
    result = response_json.get(MCP_RESULT) or {}
    error = result.get("error")
    if error is None:
        error = response_json.get("error")
    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": result.get("content"),
        "error": error,
    }
    # Push trace to the explorer - don't block on its response
    _schedule_session_message(session_id, message)
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
"""
|
||||
Convert localhost or 127.0.0.1 in an address to host.docker.internal
|
||||
|
||||
Args:
|
||||
mcp_server_base_url (str): The original server address from the header
|
||||
|
||||
Returns:
|
||||
str: Modified server address with localhost references changed to host.docker.internal
|
||||
"""
|
||||
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
|
||||
# Replace localhost or 127.0.0.1 with host.docker.internal
|
||||
modified_address = re.sub(
|
||||
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
|
||||
r"\1host.docker.internal\2",
|
||||
mcp_server_base_url,
|
||||
)
|
||||
return modified_address
|
||||
|
||||
return mcp_server_base_url
|
||||
|
||||
|
||||
def _handle_endpoint_event(
    sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str | None]:
    """
    Handle the endpoint event type and modify the data accordingly.
    For endpoint events, we need to rewrite the endpoint to use our gateway.

    Args:
        sse (ServerSentEvent): The original SSE object.
        sse_header_attributes (SseHeaderAttributes): The header attributes from the request.

    Returns:
        bytes: Modified SSE data as bytes.
        str | None: session_id extracted from the data, or None when the
            data contains no session_id.
    """
    # Default to None so an endpoint without a session_id does not raise
    # UnboundLocalError at the return below.
    session_id = None
    match = re.search(r"session_id=([^&\s]+)", sse.data)
    if match:
        session_id = match.group(1)
        # Initialize this session in our store if needed
        if not session_store.session_exists(session_id):
            session_store.initialize_session(session_id, sse_header_attributes)

    # Rewrite the endpoint to use our gateway
    modified_data = sse.data.replace(
        "/messages/?session_id=",
        "/api/v1/gateway/mcp/sse/messages/?session_id=",
    )
    event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
    return event_bytes, session_id
|
||||
|
||||
|
||||
def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
    """
    Handle the message event type.

    The event is relayed unchanged. Its JSON payload is inspected to store
    the server name in the session metadata, record tool call responses and
    store the tool list returned for "tools/list". Inspection errors are
    printed and never propagated.

    Args:
        session_id (str): The session ID associated with the request.
        sse (ServerSentEvent): The original SSE object.

    Returns:
        bytes: The event encoded as "event:"/"data:" lines.
    """
    event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
    session = session_store.get_session(session_id)
    try:
        response_json = json.loads(sse.data)
        result = response_json.get(MCP_RESULT) or {}

        server_info = result.get(MCP_SERVER_INFO)
        if server_info:
            session.metadata["mcp_server_name"] = server_info.get("name", "")

        method = None
        # Payloads that carry a "method" are requests or notifications from
        # the server, not responses; their id must not consume a pending
        # client request.
        if MCP_METHOD not in response_json:
            # Each request is answered once, so drop the entry to keep the
            # mapping from growing for the lifetime of the session.
            method = session.id_to_method_mapping.pop(response_json.get("id"), None)

        if method == MCP_TOOL_CALL:
            _hook_tool_call_response(
                session_id=session_id,
                response_json=response_json,
            )
        elif method == MCP_LIST_TOOLS and "tools" in result:
            session.metadata["tools"] = result.get("tools")
    except json.JSONDecodeError as e:
        print(
            f"[MCP SSE] Error parsing message JSON: {e}",
            flush=True,
        )
    except Exception as e:  # pylint: disable=broad-except
        print(
            f"[MCP SSE] Error processing message: {e}",
            flush=True,
        )
    return event_bytes
|
||||
@@ -7,6 +7,7 @@ from starlette_compress import CompressMiddleware
|
||||
from gateway.routes.anthropic import gateway as anthropic_gateway
|
||||
from gateway.routes.gemini import gateway as gemini_gateway
|
||||
from gateway.routes.open_ai import gateway as open_ai_gateway
|
||||
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
|
||||
|
||||
app = fastapi.app = fastapi.FastAPI(
|
||||
docs_url="/api/v1/gateway/docs",
|
||||
@@ -30,6 +31,8 @@ router.include_router(anthropic_gateway, prefix="/gateway", tags=["anthropic_gat
|
||||
|
||||
router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"])
|
||||
|
||||
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi==0.115.7",
|
||||
"httpx==0.28.1",
|
||||
"httpx-sse==0.4.0",
|
||||
"invariant-sdk>=0.0.11",
|
||||
"starlette-compress==1.4.0",
|
||||
"uvicorn==0.34.0"
|
||||
|
||||
Reference in New Issue
Block a user