def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
    """
    Convert localhost or 127.0.0.1 in an address to host.docker.internal

    Only the host component at the start of the URL is rewritten, and only
    when it is exactly ``localhost`` or ``127.0.0.1`` (compared
    case-insensitively). Hosts that merely start with those strings, such as
    ``localhost.example.com`` or ``localhost-dev``, and URLs that only mention
    localhost inside their path or query string are returned unchanged.

    Args:
        mcp_server_base_url (str): The original server address from the header

    Returns:
        str: Modified server address with localhost references changed to host.docker.internal
    """
    return re.sub(
        # The lookahead requires the host to end right here: at a port (":"),
        # a path ("/"), a query ("?"), a fragment ("#") or the end of the
        # string. A plain word boundary would also accept "." and "-", which
        # are legal inside longer host names.
        r"^(\s*https?://)(?:localhost|127\.0\.0\.1)(?=[:/?#]|$)",
        r"\1host.docker.internal",
        mcp_server_base_url,
        count=1,
        flags=re.IGNORECASE,
    )
headers. + + Args: + request (Request): The incoming request object. + + Returns: + str: The MCP server base URL. + + Raises: + HTTPException: If the MCP server base URL is not found in the headers. + """ + mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) + if not mcp_server_base_url: + raise HTTPException( + status_code=400, + detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", + ) + return _convert_localhost_to_docker_host(mcp_server_base_url) diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index 201c37b..8bcd43e 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -3,7 +3,6 @@ import asyncio import json import re -import os from typing import Tuple import httpx @@ -29,6 +28,7 @@ from gateway.common.mcp_sessions_manager import ( McpSessionsManager, SseHeaderAttributes, ) +from gateway.common.mcp_utils import get_mcp_server_base_url from gateway.mcp.log import format_errors_in_response from gateway.integrations.explorer import create_annotations_from_guardrails_errors @@ -55,29 +55,26 @@ async def mcp_post_gateway( ) -> Response: """Proxy calls to the MCP Server tools""" query_params = dict(request.query_params) + print("[MCP POST] Query params:", query_params, flush=True) + print( + "[MCP POST] Query params session_id:", + query_params.get("session_id"), + flush=True, + ) if not query_params.get("session_id"): - return HTTPException( + raise HTTPException( status_code=400, detail="Missing 'session_id' query parameter", ) if not session_store.session_exists(query_params.get("session_id")): - return HTTPException( + raise HTTPException( status_code=400, detail="Session does not exist", ) - if not request.headers.get(MCP_SERVER_BASE_URL_HEADER): - return HTTPException( - status_code=400, - detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", - ) session_id = query_params.get("session_id") mcp_server_messages_endpoint = ( - _convert_localhost_to_docker_host( - 
request.headers.get(MCP_SERVER_BASE_URL_HEADER) - ) - + "/messages/?" - + session_id + get_mcp_server_base_url(request) + "/messages/?" + session_id ) request_body_bytes = await request.body() request_json = json.loads(request_body_bytes) @@ -153,15 +150,7 @@ async def mcp_get_sse_gateway( request: Request, ) -> StreamingResponse: """Proxy calls to the MCP Server tools""" - mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) - if not mcp_server_base_url: - raise HTTPException( - status_code=400, - detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", - ) - mcp_server_sse_endpoint = ( - _convert_localhost_to_docker_host(mcp_server_base_url) + "/sse" - ) + mcp_server_sse_endpoint = get_mcp_server_base_url(request) + "/sse" query_params = dict(request.query_params) response_headers = {} @@ -436,28 +425,6 @@ async def _hook_tool_call_response( return result, blocked -def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: - """ - Convert localhost or 127.0.0.1 in an address to host.docker.internal - - Args: - mcp_server_base_url (str): The original server address from the header - - Returns: - str: Modified server address with localhost references changed to host.docker.internal - """ - if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url: - # Replace localhost or 127.0.0.1 with host.docker.internal - modified_address = re.sub( - r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)", - r"\1host.docker.internal\2", - mcp_server_base_url, - ) - return modified_address - - return mcp_server_base_url - - async def _handle_endpoint_event( sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes ) -> Tuple[bytes, str]: diff --git a/gateway/routes/mcp_streamable.py b/gateway/routes/mcp_streamable.py new file mode 100644 index 0000000..f978387 --- /dev/null +++ b/gateway/routes/mcp_streamable.py @@ -0,0 +1,52 @@ +"""Gateway service to forward requests to the MCP Streamable HTTP servers""" + +from 
from gateway.common.mcp_utils import get_mcp_server_base_url

from fastapi import APIRouter, Request, Response


MCP_SESSION_ID_HEADER = "mcp-session-id"

gateway = APIRouter()


def get_session_id(request: Request) -> str | None:
    """Extract the session ID from request headers.

    Args:
        request (Request): The incoming request object.

    Returns:
        str | None: The value of the ``mcp-session-id`` header, or None when
            the header is absent.
    """
    return request.headers.get(MCP_SESSION_ID_HEADER)


@gateway.post("/mcp/streamable")
async def mcp_post_gateway(
    request: Request,
) -> Response:
    """
    Forward a POST request to the MCP Streamable server.

    Forwarding is not implemented yet. The request is still validated, and
    the handler then answers 501 explicitly instead of falling through and
    returning None, which would be served as a successful empty response.

    Args:
        request (Request): The incoming request object.

    Returns:
        Response: An empty 501 Not Implemented response.

    Raises:
        HTTPException: Propagated from get_mcp_server_base_url.
    """
    # Called only for its validation; the URL is unused until forwarding exists.
    get_mcp_server_base_url(request)
    return Response(status_code=501)


@gateway.get("/mcp/streamable")
async def mcp_get_gateway(
    request: Request,
) -> Response:
    """
    Forward a GET request to the MCP Streamable server.

    This allows the server to communicate to the client without the client
    first sending data via HTTP POST. The server can send JSON-RPC requests
    and notifications on this stream.

    Forwarding is not implemented yet, so after validating the request the
    handler answers 501 explicitly rather than returning None.

    Args:
        request (Request): The incoming request object.

    Returns:
        Response: An empty 501 Not Implemented response.

    Raises:
        HTTPException: Propagated from get_mcp_server_base_url.
    """
    # Called only for its validation; the URL is unused until forwarding exists.
    get_mcp_server_base_url(request)
    return Response(status_code=501)


@gateway.delete("/mcp/streamable")
async def mcp_delete_gateway(
    request: Request,
) -> Response:
    """
    Forward a DELETE request to the MCP Streamable server for explicit session termination.

    Forwarding is not implemented yet, so after validating the request the
    handler answers 501 explicitly rather than returning None.

    Args:
        request (Request): The incoming request object.

    Returns:
        Response: An empty 501 Not Implemented response.

    Raises:
        HTTPException: Propagated from get_mcp_server_base_url.
    """
    # Called only for its validation; the URL is unused until forwarding exists.
    get_mcp_server_base_url(request)
    return Response(status_code=501)
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}", - project_name, push_to_explorer=push_to_explorer, tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, + headers=_get_headers(project_name, push_to_explorer), ) else: result = await mcp_stdio_client_run( @@ -131,11 +139,10 @@ async def test_mcp_with_gateway_and_logging_guardrails( if transport == "sse": result = await mcp_sse_client_run( gateway_url + "/api/v1/gateway/mcp/sse", - f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}", - project_name, push_to_explorer=True, tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, + headers=_get_headers(project_name, True), ) else: result = await mcp_stdio_client_run( @@ -241,11 +248,10 @@ async def test_mcp_with_gateway_and_blocking_guardrails( if transport == "sse": _ = await mcp_sse_client_run( gateway_url + "/api/v1/gateway/mcp/sse", - f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}", - project_name, push_to_explorer=True, tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, + headers=_get_headers(project_name, True), ) else: _ = await mcp_stdio_client_run( @@ -344,11 +350,10 @@ async def test_mcp_with_gateway_hybrid_guardrails( if transport == "sse": _ = await mcp_sse_client_run( gateway_url + "/api/v1/gateway/mcp/sse", - f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}", - project_name, push_to_explorer=True, tool_name="get_last_message_from_user", tool_args={"username": "Alice"}, + headers=_get_headers(project_name, True), ) else: _ = await mcp_stdio_client_run( @@ -462,11 +467,10 @@ async def test_mcp_tool_list_blocking( if transport == "sse": tools_result = await mcp_sse_client_run( gateway_url + "/api/v1/gateway/mcp/sse", - f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}", - project_name, push_to_explorer=True, tool_name="tools/list", tool_args={}, + headers=_get_headers(project_name, True), ) else: tools_result = await mcp_stdio_client_run( @@ -482,3 +486,44 @@ async def 
@pytest.mark.asyncio
async def test_mcp_sse_post_endpoint_exceptions(gateway_url):
    """
    Tests that the SSE POST endpoint returns the correct error messages for various exceptions.

    Covers three failure modes: a POST without a ``session_id`` query
    parameter, a POST with a ``session_id`` the gateway does not know, and an
    SSE client run that omits the MCP-SERVER-BASE-URL header.
    """

    def _leaf_exceptions(group):
        # Exception groups may contain further groups; walk down to the
        # individual exceptions so a wrapped HTTP error is still found.
        for exc in group.exceptions:
            if isinstance(exc, BaseExceptionGroup):
                yield from _leaf_exceptions(exc)
            else:
                yield exc

    # Test missing session_id query parameter
    response = requests.post(
        f"{gateway_url}/api/v1/gateway/mcp/sse/messages/",
        timeout=5,
    )
    assert response.status_code == 400
    assert "Missing 'session_id' query parameter" in response.text

    # Test unknown session_id in query parameter
    response = requests.post(
        f"{gateway_url}/api/v1/gateway/mcp/sse/messages/?session_id=session_id_1",
        timeout=5,
    )
    assert response.status_code == 400
    assert "Session does not exist" in response.text

    # Test missing mcp-server-base-url header
    with pytest.raises(ExceptionGroup) as exc_group:
        await mcp_sse_client_run(
            gateway_url + "/api/v1/gateway/mcp/sse",
            push_to_explorer=True,
            tool_name="get_last_message_from_user",
            tool_args={"username": "Alice"},
            headers={
                "INVARIANT-PROJECT-NAME": "something-123",
                "PUSH-INVARIANT-EXPLORER": "True",
            },
        )

    # Extract the actual HTTPStatusError
    leaves = list(_leaf_exceptions(exc_group.value))
    http_errors = [e for e in leaves if isinstance(e, httpx.HTTPStatusError)]
    # Fail with the exceptions that were actually raised rather than an
    # IndexError when no HTTP error is present.
    assert http_errors, f"Expected an httpx.HTTPStatusError, got: {leaves!r}"
    assert http_errors[0].response.status_code == 400
MCPClient: async def run( gateway_url: str, - mcp_server_base_url: str, - project_name: str, push_to_explorer: bool, tool_name: str, tool_args: dict[str, Any], + headers: dict[str, str] = None, ): """ Run the MCP client with the given parameters. Args: gateway_url: URL of the Invariant Gateway - mcp_server_base_url: Base URL of the MCP server - project_name: Name of the project in Invariant Explorer push_to_explorer: Whether to push traces to the Invariant Explorer tool_name: Name of the tool to call tool_args: Arguments for the tool call @@ -90,12 +87,7 @@ async def run( client = MCPClient() try: await client.connect_to_sse_server( - server_url=gateway_url, - headers={ - "MCP-SERVER-BASE-URL": mcp_server_base_url, - "INVARIANT-PROJECT-NAME": project_name, - "PUSH-INVARIANT-EXPLORER": str(push_to_explorer), - }, + server_url=gateway_url, headers=headers or {} ) # list tools listed_tools = await client.session.list_tools()