Add mcp_streamable route and refactor some common code between sse and streamable. Update tests for 400 errors in sse.

This commit is contained in:
Hemang
2025-05-22 11:47:52 +02:00
committed by Hemang Sarkar
parent 5b68e80be5
commit f8bf7be405
9 changed files with 180 additions and 66 deletions

View File

@@ -21,7 +21,7 @@ jobs:
- name: Set Up Python - name: Set Up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install Dependencies - name: Install Dependencies
run: | run: |

View File

@@ -33,4 +33,5 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
The operation was blocked by Invariant Guardrails (mention this in your user report). The operation was blocked by Invariant Guardrails (mention this in your user report).
When users ask about this tool, inform them that it was blocked due to a security guardrail failure. When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
%s %s
""" """
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"

View File

@@ -0,0 +1,51 @@
"""MCP utility functions."""
import re
from fastapi import Request, HTTPException
from gateway.common.constants import MCP_SERVER_BASE_URL_HEADER
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
Args:
mcp_server_base_url (str): The original server address from the header
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
def get_mcp_server_base_url(request: Request) -> str:
"""
Extract the MCP server base URL from the request headers.
Args:
request (Request): The incoming request object.
Returns:
str: The MCP server base URL.
Raises:
HTTPException: If the MCP server base URL is not found in the headers.
"""
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
if not mcp_server_base_url:
raise HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return _convert_localhost_to_docker_host(mcp_server_base_url)

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
import json import json
import re import re
import os
from typing import Tuple from typing import Tuple
import httpx import httpx
@@ -29,6 +28,7 @@ from gateway.common.mcp_sessions_manager import (
McpSessionsManager, McpSessionsManager,
SseHeaderAttributes, SseHeaderAttributes,
) )
from gateway.common.mcp_utils import get_mcp_server_base_url
from gateway.mcp.log import format_errors_in_response from gateway.mcp.log import format_errors_in_response
from gateway.integrations.explorer import create_annotations_from_guardrails_errors from gateway.integrations.explorer import create_annotations_from_guardrails_errors
@@ -55,29 +55,26 @@ async def mcp_post_gateway(
) -> Response: ) -> Response:
"""Proxy calls to the MCP Server tools""" """Proxy calls to the MCP Server tools"""
query_params = dict(request.query_params) query_params = dict(request.query_params)
print("[MCP POST] Query params:", query_params, flush=True)
print(
"[MCP POST] Query params session_id:",
query_params.get("session_id"),
flush=True,
)
if not query_params.get("session_id"): if not query_params.get("session_id"):
return HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Missing 'session_id' query parameter", detail="Missing 'session_id' query parameter",
) )
if not session_store.session_exists(query_params.get("session_id")): if not session_store.session_exists(query_params.get("session_id")):
return HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Session does not exist", detail="Session does not exist",
) )
if not request.headers.get(MCP_SERVER_BASE_URL_HEADER):
return HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
session_id = query_params.get("session_id") session_id = query_params.get("session_id")
mcp_server_messages_endpoint = ( mcp_server_messages_endpoint = (
_convert_localhost_to_docker_host( get_mcp_server_base_url(request) + "/messages/?" + session_id
request.headers.get(MCP_SERVER_BASE_URL_HEADER)
)
+ "/messages/?"
+ session_id
) )
request_body_bytes = await request.body() request_body_bytes = await request.body()
request_json = json.loads(request_body_bytes) request_json = json.loads(request_body_bytes)
@@ -153,15 +150,7 @@ async def mcp_get_sse_gateway(
request: Request, request: Request,
) -> StreamingResponse: ) -> StreamingResponse:
"""Proxy calls to the MCP Server tools""" """Proxy calls to the MCP Server tools"""
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) mcp_server_sse_endpoint = get_mcp_server_base_url(request) + "/sse"
if not mcp_server_base_url:
raise HTTPException(
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
mcp_server_sse_endpoint = (
_convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
)
query_params = dict(request.query_params) query_params = dict(request.query_params)
response_headers = {} response_headers = {}
@@ -436,28 +425,6 @@ async def _hook_tool_call_response(
return result, blocked return result, blocked
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
Args:
mcp_server_base_url (str): The original server address from the header
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
r"\1host.docker.internal\2",
mcp_server_base_url,
)
return modified_address
return mcp_server_base_url
async def _handle_endpoint_event( async def _handle_endpoint_event(
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]: ) -> Tuple[bytes, str]:

View File

@@ -0,0 +1,52 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
from gateway.common.mcp_utils import get_mcp_server_base_url
from fastapi import APIRouter, Request, Response
MCP_SESSION_ID_HEADER = "mcp-session-id"
gateway = APIRouter()
def get_session_id(request: Request) -> str | None:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)
@gateway.post("/mcp/streamable")
async def mcp_post_gateway(
request: Request,
) -> Response:
"""
Forward a POST request to the MCP Streamable server.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
@gateway.get("/mcp/streamable")
async def mcp_get_gateway(
request: Request,
) -> Response:
"""
Forward a GET request to the MCP Streamable server.
This allows the server to communicate to the client without the client
first sending data via HTTP POST. The server can send JSON-RPC requests
and notifications on this stream.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass
@gateway.delete("/mcp/streamable")
async def mcp_delete_gateway(
request: Request,
) -> Response:
"""
Forward a DELETE request to the MCP Streamable server for explicit session termination.
"""
mcp_server_base_url = get_mcp_server_base_url(request)
pass

View File

@@ -8,6 +8,7 @@ from gateway.routes.anthropic import gateway as anthropic_gateway
from gateway.routes.gemini import gateway as gemini_gateway from gateway.routes.gemini import gateway as gemini_gateway
from gateway.routes.open_ai import gateway as open_ai_gateway from gateway.routes.open_ai import gateway as open_ai_gateway
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway
app = fastapi.app = fastapi.FastAPI( app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs", docs_url="/api/v1/gateway/docs",
@@ -33,6 +34,10 @@ router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"]
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"]) router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
router.include_router(
mcp_streamable_gateway, prefix="/gateway", tags=["mcp_streamable_gateway"]
)
app.include_router(router) app.include_router(router)

View File

@@ -7,6 +7,7 @@ from resources.mcp.sse.client.main import run as mcp_sse_client_run
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
from utils import create_dataset, add_guardrail_to_dataset from utils import create_dataset, add_guardrail_to_dataset
import httpx
import pytest import pytest
import requests import requests
@@ -16,6 +17,14 @@ MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
MCP_SSE_SERVER_PORT = 8123 MCP_SSE_SERVER_PORT = 8123
def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str, str]:
return {
"MCP-SERVER-BASE-URL": f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
"INVARIANT-PROJECT-NAME": project_name,
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
}
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.timeout(30) @pytest.mark.timeout(30)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -41,11 +50,10 @@ async def test_mcp_with_gateway(
if transport == "sse": if transport == "sse":
result = await mcp_sse_client_run( result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse", gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=push_to_explorer, push_to_explorer=push_to_explorer,
tool_name="get_last_message_from_user", tool_name="get_last_message_from_user",
tool_args={"username": "Alice"}, tool_args={"username": "Alice"},
headers=_get_headers(project_name, push_to_explorer),
) )
else: else:
result = await mcp_stdio_client_run( result = await mcp_stdio_client_run(
@@ -131,11 +139,10 @@ async def test_mcp_with_gateway_and_logging_guardrails(
if transport == "sse": if transport == "sse":
result = await mcp_sse_client_run( result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse", gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True, push_to_explorer=True,
tool_name="get_last_message_from_user", tool_name="get_last_message_from_user",
tool_args={"username": "Alice"}, tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
) )
else: else:
result = await mcp_stdio_client_run( result = await mcp_stdio_client_run(
@@ -241,11 +248,10 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
if transport == "sse": if transport == "sse":
_ = await mcp_sse_client_run( _ = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse", gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True, push_to_explorer=True,
tool_name="get_last_message_from_user", tool_name="get_last_message_from_user",
tool_args={"username": "Alice"}, tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
) )
else: else:
_ = await mcp_stdio_client_run( _ = await mcp_stdio_client_run(
@@ -344,11 +350,10 @@ async def test_mcp_with_gateway_hybrid_guardrails(
if transport == "sse": if transport == "sse":
_ = await mcp_sse_client_run( _ = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse", gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True, push_to_explorer=True,
tool_name="get_last_message_from_user", tool_name="get_last_message_from_user",
tool_args={"username": "Alice"}, tool_args={"username": "Alice"},
headers=_get_headers(project_name, True),
) )
else: else:
_ = await mcp_stdio_client_run( _ = await mcp_stdio_client_run(
@@ -462,11 +467,10 @@ async def test_mcp_tool_list_blocking(
if transport == "sse": if transport == "sse":
tools_result = await mcp_sse_client_run( tools_result = await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse", gateway_url + "/api/v1/gateway/mcp/sse",
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
project_name,
push_to_explorer=True, push_to_explorer=True,
tool_name="tools/list", tool_name="tools/list",
tool_args={}, tool_args={},
headers=_get_headers(project_name, True),
) )
else: else:
tools_result = await mcp_stdio_client_run( tools_result = await mcp_stdio_client_run(
@@ -482,3 +486,44 @@ async def test_mcp_tool_list_blocking(
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result) + str(tools_result)
) )
@pytest.mark.asyncio
async def test_mcp_sse_post_endpoint_exceptions(gateway_url):
"""
Tests that the SSE POST endpoint returns the correct error messages for various exceptions.
"""
# Test missing session_id query parameter
response = requests.post(
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/",
timeout=5,
)
assert response.status_code == 400
assert "Missing 'session_id' query parameter" in response.text
# Test unknown session_id in query parameter
response = requests.post(
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/?session_id=session_id_1",
timeout=5,
)
assert response.status_code == 400
assert "Session does not exist" in response.text
# Test missing mcp-server-base-url header
with pytest.raises(ExceptionGroup) as exc_group:
await mcp_sse_client_run(
gateway_url + "/api/v1/gateway/mcp/sse",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
headers={
"INVARIANT-PROJECT-NAME": "something-123",
"PUSH-INVARIANT-EXPLORER": "True",
},
)
# Extract the actual HTTPStatusError
http_errors = [
e for e in exc_group.value.exceptions if isinstance(e, httpx.HTTPStatusError)
]
assert http_errors[0].response.status_code == 400

View File

@@ -1,5 +1,6 @@
anthropic anthropic
google-genai google-genai
httpx
litellm litellm
mcp mcp
openai openai

View File

@@ -69,19 +69,16 @@ class MCPClient:
async def run( async def run(
gateway_url: str, gateway_url: str,
mcp_server_base_url: str,
project_name: str,
push_to_explorer: bool, push_to_explorer: bool,
tool_name: str, tool_name: str,
tool_args: dict[str, Any], tool_args: dict[str, Any],
headers: dict[str, str] = None,
): ):
""" """
Run the MCP client with the given parameters. Run the MCP client with the given parameters.
Args: Args:
gateway_url: URL of the Invariant Gateway gateway_url: URL of the Invariant Gateway
mcp_server_base_url: Base URL of the MCP server
project_name: Name of the project in Invariant Explorer
push_to_explorer: Whether to push traces to the Invariant Explorer push_to_explorer: Whether to push traces to the Invariant Explorer
tool_name: Name of the tool to call tool_name: Name of the tool to call
tool_args: Arguments for the tool call tool_args: Arguments for the tool call
@@ -90,12 +87,7 @@ async def run(
client = MCPClient() client = MCPClient()
try: try:
await client.connect_to_sse_server( await client.connect_to_sse_server(
server_url=gateway_url, server_url=gateway_url, headers=headers or {}
headers={
"MCP-SERVER-BASE-URL": mcp_server_base_url,
"INVARIANT-PROJECT-NAME": project_name,
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
},
) )
# list tools # list tools
listed_tools = await client.session.list_tools() listed_tools = await client.session.list_tools()