Add mcp_streamable route and refactor some common code between sse and streamable. Update tests for 400 errors in sse.

This commit is contained in:
Hemang
2025-05-22 11:47:52 +02:00
committed by Hemang Sarkar
parent 5b68e80be5
commit f8bf7be405
9 changed files with 180 additions and 66 deletions

View File

@@ -21,7 +21,7 @@ jobs:
- name: Set Up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Dependencies
run: |

View File

@@ -34,3 +34,4 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
%s
"""
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"

View File

@@ -0,0 +1,51 @@
"""MCP utility functions."""
import re
from fastapi import Request, HTTPException
from gateway.common.constants import MCP_SERVER_BASE_URL_HEADER
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
Args:
mcp_server_base_url (str): The original server address from the header
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
def get_mcp_server_base_url(request: Request) -> str:
"""
Extract the MCP server base URL from the request headers.
Args:
request (Request): The incoming request object.
Returns:
str: The MCP server base URL.
Raises:
HTTPException: If the MCP server base URL is not found in the headers.
"""
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
if not mcp_server_base_url:
raise HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return _convert_localhost_to_docker_host(mcp_server_base_url)

View File

@@ -3,7 +3,6 @@
import asyncio
import json
import re
import os
from typing import Tuple
import httpx
@@ -29,6 +28,7 @@ from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
)
from gateway.common.mcp_utils import get_mcp_server_base_url
from gateway.mcp.log import format_errors_in_response
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
@@ -55,29 +55,26 @@ async def mcp_post_gateway(
) -> Response:
"""Proxy calls to the MCP Server tools"""
query_params = dict(request.query_params)
print("[MCP POST] Query params:", query_params, flush=True)
print(
"[MCP POST] Query params session_id:",
query_params.get("session_id"),
flush=True,
)
if not query_params.get("session_id"):
return HTTPException(
raise HTTPException(
status_code=400,
detail="Missing 'session_id' query parameter",
)
if not session_store.session_exists(query_params.get("session_id")):
return HTTPException(
raise HTTPException(
status_code=400,
detail="Session does not exist",
)
if not request.headers.get(MCP_SERVER_BASE_URL_HEADER):
return HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
session_id = query_params.get("session_id")
mcp_server_messages_endpoint = (
_convert_localhost_to_docker_host(
request.headers.get(MCP_SERVER_BASE_URL_HEADER)
)
+ "/messages/?"
+ session_id
get_mcp_server_base_url(request) + "/messages/?" + session_id
)
request_body_bytes = await request.body()
request_json = json.loads(request_body_bytes)
@@ -153,15 +150,7 @@ async def mcp_get_sse_gateway(
request: Request,
) -> StreamingResponse:
"""Proxy calls to the MCP Server tools"""
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
if not mcp_server_base_url:
raise HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
mcp_server_sse_endpoint = (
_convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
)
mcp_server_sse_endpoint = get_mcp_server_base_url(request) + "/sse"
query_params = dict(request.query_params)
response_headers = {}
@@ -436,28 +425,6 @@ async def _hook_tool_call_response(
return result, blocked
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
Args:
mcp_server_base_url (str): The original server address from the header
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
async def _handle_endpoint_event(
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]:

View File

@@ -0,0 +1,52 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
from gateway.common.mcp_utils import get_mcp_server_base_url
from fastapi import APIRouter, Request, Response
MCP_SESSION_ID_HEADER = "mcp-session-id"
gateway = APIRouter()
def get_session_id(request: Request) -> str | None:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
@gateway.post("/mcp/streamable")
async def mcp_post_gateway(
request: Request,
) -> Response:
"""
Forward a POST request to the MCP Streamable server.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
@gateway.get("/mcp/streamable")
async def mcp_get_gateway(
request: Request,
) -> Response:
"""
Forward a GET request to the MCP Streamable server.
This allows the server to communicate to the client without the client
first sending data via HTTP POST. The server can send JSON-RPC requests
and notifications on this stream.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
@gateway.delete("/mcp/streamable")
async def mcp_delete_gateway(
request: Request,
) -> Response:
"""
Forward a DELETE request to the MCP Streamable server for explicit session termination.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass

View File

@@ -8,6 +8,7 @@ from gateway.routes.anthropic import gateway as anthropic_gateway
from gateway.routes.gemini import gateway as gemini_gateway
from gateway.routes.open_ai import gateway as open_ai_gateway
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway
app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs",
@@ -33,6 +34,10 @@ router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"]
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
router.include_router(
mcp_streamable_gateway, prefix="/gateway", tags=["mcp_streamable_gateway"]
)
app.include_router(router)

View File

@@ -7,6 +7,7 @@ from resources.mcp.sse.client.main import run as mcp_sse_client_run
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
from utils import create_dataset, add_guardrail_to_dataset
import httpx
import pytest
import requests
@@ -16,6 +17,14 @@ MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
MCP_SSE_SERVER_PORT = 8123
def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str, str]:
return {
"MCP-SERVER-BASE-URL": f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
"INVARIANT-PROJECT-NAME": project_name,
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
}
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
@@ -41,11 +50,10 @@ async def test_mcp_with_gateway(
if transport == "sse":
result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=push_to_explorer,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, push_to_explorer),
)
else:
result = await mcp_stdio_client_run(
@@ -131,11 +139,10 @@ async def test_mcp_with_gateway_and_logging_guardrails(
if transport == "sse":
result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
)
else:
result = await mcp_stdio_client_run(
@@ -241,11 +248,10 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
if transport == "sse":
_ = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
)
else:
_ = await mcp_stdio_client_run(
@@ -344,11 +350,10 @@ async def test_mcp_with_gateway_hybrid_guardrails(
if transport == "sse":
_ = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
)
else:
_ = await mcp_stdio_client_run(
@@ -462,11 +467,10 @@ async def test_mcp_tool_list_blocking(
if transport == "sse":
tools_result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True,
tool_name="tools/list",
tool_args={},
headers=_get_headers(project_name, True),
)
else:
tools_result = await mcp_stdio_client_run(
@@ -482,3 +486,44 @@ async def test_mcp_tool_list_blocking(
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result)
)
@pytest.mark.asyncio
async def test_mcp_sse_post_endpoint_exceptions(gateway_url):
"""
Tests that the SSE POST endpoint returns the correct error messages for various exceptions.
"""
# Test missing session_id query parameter
response = requests.post(
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/",
timeout=5,
)
assert response.status_code == 400
assert "Missing 'session_id' query parameter" in response.text
# Test unknown session_id in query parameter
response = requests.post(
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/?session_id=session_id_1",
timeout=5,
)
assert response.status_code == 400
assert "Session does not exist" in response.text
# Test missing mcp-server-base-url header
with pytest.raises(ExceptionGroup) as exc_group:
await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers={
"INVARIANT-PROJECT-NAME": "something-123",
"PUSH-INVARIANT-EXPLORER": "True",
},
)
# Extract the actual HTTPStatusError
http_errors = [
e for e in exc_group.value.exceptions if isinstance(e, httpx.HTTPStatusError)
]
assert http_errors[0].response.status_code == 400

View File

@@ -1,5 +1,6 @@
anthropic
google-genai
httpx
litellm
mcp
openai

View File

@@ -69,19 +69,16 @@ class MCPClient:
async def run(
gateway_url: str,
mcp_server_base_url: str,
project_name: str,
push_to_explorer: bool,
tool_name: str,
tool_args: dict[str, Any],
headers: dict[str, str] = None,
):
"""
Run the MCP client with the given parameters.
Args:
gateway_url: URL of the Invariant Gateway
mcp_server_base_url: Base URL of the MCP server
project_name: Name of the project in Invariant Explorer
push_to_explorer: Whether to push traces to the Invariant Explorer
tool_name: Name of the tool to call
tool_args: Arguments for the tool call
@@ -90,12 +87,7 @@ async def run(
client = MCPClient()
try:
await client.connect_to_sse_server(
server_url=gateway_url,
headers={
"MCP-SERVER-BASE-URL": mcp_server_base_url,
"INVARIANT-PROJECT-NAME": project_name,
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
},
server_url=gateway_url, headers=headers or {}
)
# list tools
listed_tools = await client.session.list_tools()