mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Add mcp_streamable route and refactor some common code between sse and streamable. Update tests for 400 errors in sse.
This commit is contained in:
2
.github/workflows/tests_ci.yml
vendored
2
.github/workflows/tests_ci.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
- name: Set Up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
|
||||
@@ -33,4 +33,5 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
|
||||
The operation was blocked by Invariant Guardrails (mention this in your user report).
|
||||
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
|
||||
%s
|
||||
"""
|
||||
"""
|
||||
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
|
||||
|
||||
51
gateway/common/mcp_utils.py
Normal file
51
gateway/common/mcp_utils.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""MCP utility functions."""
|
||||
|
||||
import re
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
|
||||
from gateway.common.constants import MCP_SERVER_BASE_URL_HEADER
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
"""
|
||||
Convert localhost or 127.0.0.1 in an address to host.docker.internal
|
||||
|
||||
Args:
|
||||
mcp_server_base_url (str): The original server address from the header
|
||||
|
||||
Returns:
|
||||
str: Modified server address with localhost references changed to host.docker.internal
|
||||
"""
|
||||
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
|
||||
# Replace localhost or 127.0.0.1 with host.docker.internal
|
||||
modified_address = re.sub(
|
||||
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
|
||||
r"\1host.docker.internal\2",
|
||||
mcp_server_base_url,
|
||||
)
|
||||
return modified_address
|
||||
|
||||
return mcp_server_base_url
|
||||
|
||||
|
||||
def get_mcp_server_base_url(request: Request) -> str:
|
||||
"""
|
||||
Extract the MCP server base URL from the request headers.
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
|
||||
Returns:
|
||||
str: The MCP server base URL.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the MCP server base URL is not found in the headers.
|
||||
"""
|
||||
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
|
||||
if not mcp_server_base_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
return _convert_localhost_to_docker_host(mcp_server_base_url)
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
@@ -29,6 +28,7 @@ from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import get_mcp_server_base_url
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
|
||||
@@ -55,29 +55,26 @@ async def mcp_post_gateway(
|
||||
) -> Response:
|
||||
"""Proxy calls to the MCP Server tools"""
|
||||
query_params = dict(request.query_params)
|
||||
print("[MCP POST] Query params:", query_params, flush=True)
|
||||
print(
|
||||
"[MCP POST] Query params session_id:",
|
||||
query_params.get("session_id"),
|
||||
flush=True,
|
||||
)
|
||||
if not query_params.get("session_id"):
|
||||
return HTTPException(
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing 'session_id' query parameter",
|
||||
)
|
||||
if not session_store.session_exists(query_params.get("session_id")):
|
||||
return HTTPException(
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Session does not exist",
|
||||
)
|
||||
if not request.headers.get(MCP_SERVER_BASE_URL_HEADER):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
|
||||
session_id = query_params.get("session_id")
|
||||
mcp_server_messages_endpoint = (
|
||||
_convert_localhost_to_docker_host(
|
||||
request.headers.get(MCP_SERVER_BASE_URL_HEADER)
|
||||
)
|
||||
+ "/messages/?"
|
||||
+ session_id
|
||||
get_mcp_server_base_url(request) + "/messages/?" + session_id
|
||||
)
|
||||
request_body_bytes = await request.body()
|
||||
request_json = json.loads(request_body_bytes)
|
||||
@@ -153,15 +150,7 @@ async def mcp_get_sse_gateway(
|
||||
request: Request,
|
||||
) -> StreamingResponse:
|
||||
"""Proxy calls to the MCP Server tools"""
|
||||
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
|
||||
if not mcp_server_base_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
mcp_server_sse_endpoint = (
|
||||
_convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
|
||||
)
|
||||
mcp_server_sse_endpoint = get_mcp_server_base_url(request) + "/sse"
|
||||
|
||||
query_params = dict(request.query_params)
|
||||
response_headers = {}
|
||||
@@ -436,28 +425,6 @@ async def _hook_tool_call_response(
|
||||
return result, blocked
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
"""
|
||||
Convert localhost or 127.0.0.1 in an address to host.docker.internal
|
||||
|
||||
Args:
|
||||
mcp_server_base_url (str): The original server address from the header
|
||||
|
||||
Returns:
|
||||
str: Modified server address with localhost references changed to host.docker.internal
|
||||
"""
|
||||
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
|
||||
# Replace localhost or 127.0.0.1 with host.docker.internal
|
||||
modified_address = re.sub(
|
||||
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
|
||||
r"\1host.docker.internal\2",
|
||||
mcp_server_base_url,
|
||||
)
|
||||
return modified_address
|
||||
|
||||
return mcp_server_base_url
|
||||
|
||||
|
||||
async def _handle_endpoint_event(
|
||||
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
|
||||
) -> Tuple[bytes, str]:
|
||||
|
||||
52
gateway/routes/mcp_streamable.py
Normal file
52
gateway/routes/mcp_streamable.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
from gateway.common.mcp_utils import get_mcp_server_base_url
|
||||
|
||||
from fastapi import APIRouter, Request, Response
|
||||
|
||||
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
|
||||
def get_session_id(request: Request) -> str | None:
|
||||
"""Extract the session ID from request headers."""
|
||||
return request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
|
||||
@gateway.post("/mcp/streamable")
|
||||
async def mcp_post_gateway(
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""
|
||||
Forward a POST request to the MCP Streamable server.
|
||||
"""
|
||||
mcp_server_base_url = get_mcp_server_base_url(request)
|
||||
pass
|
||||
|
||||
|
||||
@gateway.get("/mcp/streamable")
|
||||
async def mcp_get_gateway(
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""
|
||||
Forward a GET request to the MCP Streamable server.
|
||||
|
||||
This allows the server to communicate to the client without the client
|
||||
first sending data via HTTP POST. The server can send JSON-RPC requests
|
||||
and notifications on this stream.
|
||||
"""
|
||||
mcp_server_base_url = get_mcp_server_base_url(request)
|
||||
pass
|
||||
|
||||
|
||||
@gateway.delete("/mcp/streamable")
|
||||
async def mcp_delete_gateway(
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""
|
||||
Forward a DELETE request to the MCP Streamable server for explicit session termination.
|
||||
"""
|
||||
mcp_server_base_url = get_mcp_server_base_url(request)
|
||||
pass
|
||||
@@ -8,6 +8,7 @@ from gateway.routes.anthropic import gateway as anthropic_gateway
|
||||
from gateway.routes.gemini import gateway as gemini_gateway
|
||||
from gateway.routes.open_ai import gateway as open_ai_gateway
|
||||
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
|
||||
from gateway.routes.mcp_streamable import gateway as mcp_streamable_gateway
|
||||
|
||||
app = fastapi.app = fastapi.FastAPI(
|
||||
docs_url="/api/v1/gateway/docs",
|
||||
@@ -33,6 +34,10 @@ router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"]
|
||||
|
||||
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
|
||||
|
||||
router.include_router(
|
||||
mcp_streamable_gateway, prefix="/gateway", tags=["mcp_streamable_gateway"]
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from resources.mcp.sse.client.main import run as mcp_sse_client_run
|
||||
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
|
||||
from utils import create_dataset, add_guardrail_to_dataset
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
@@ -16,6 +17,14 @@ MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
|
||||
MCP_SSE_SERVER_PORT = 8123
|
||||
|
||||
|
||||
def _get_headers(project_name: str, push_to_explorer: bool = True) -> dict[str, str]:
|
||||
return {
|
||||
"MCP-SERVER-BASE-URL": f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
"INVARIANT-PROJECT-NAME": project_name,
|
||||
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -41,11 +50,10 @@ async def test_mcp_with_gateway(
|
||||
if transport == "sse":
|
||||
result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
project_name,
|
||||
push_to_explorer=push_to_explorer,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, push_to_explorer),
|
||||
)
|
||||
else:
|
||||
result = await mcp_stdio_client_run(
|
||||
@@ -131,11 +139,10 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
if transport == "sse":
|
||||
result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
project_name,
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
)
|
||||
else:
|
||||
result = await mcp_stdio_client_run(
|
||||
@@ -241,11 +248,10 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
if transport == "sse":
|
||||
_ = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
project_name,
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
)
|
||||
else:
|
||||
_ = await mcp_stdio_client_run(
|
||||
@@ -344,11 +350,10 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
if transport == "sse":
|
||||
_ = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
project_name,
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers=_get_headers(project_name, True),
|
||||
)
|
||||
else:
|
||||
_ = await mcp_stdio_client_run(
|
||||
@@ -462,11 +467,10 @@ async def test_mcp_tool_list_blocking(
|
||||
if transport == "sse":
|
||||
tools_result = await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
f"http://{MCP_SSE_SERVER_HOST}:{MCP_SSE_SERVER_PORT}",
|
||||
project_name,
|
||||
push_to_explorer=True,
|
||||
tool_name="tools/list",
|
||||
tool_args={},
|
||||
headers=_get_headers(project_name, True),
|
||||
)
|
||||
else:
|
||||
tools_result = await mcp_stdio_client_run(
|
||||
@@ -482,3 +486,44 @@ async def test_mcp_tool_list_blocking(
|
||||
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
|
||||
+ str(tools_result)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_sse_post_endpoint_exceptions(gateway_url):
|
||||
"""
|
||||
Tests that the SSE POST endpoint returns the correct error messages for various exceptions.
|
||||
"""
|
||||
# Test missing session_id query parameter
|
||||
response = requests.post(
|
||||
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/",
|
||||
timeout=5,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Missing 'session_id' query parameter" in response.text
|
||||
|
||||
# Test unknown session_id in query parameter
|
||||
response = requests.post(
|
||||
f"{gateway_url}/api/v1/gateway/mcp/sse/messages/?session_id=session_id_1",
|
||||
timeout=5,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Session does not exist" in response.text
|
||||
|
||||
# Test missing mcp-server-base-url header
|
||||
with pytest.raises(ExceptionGroup) as exc_group:
|
||||
await mcp_sse_client_run(
|
||||
gateway_url + "/api/v1/gateway/mcp/sse",
|
||||
push_to_explorer=True,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"},
|
||||
headers={
|
||||
"INVARIANT-PROJECT-NAME": "something-123",
|
||||
"PUSH-INVARIANT-EXPLORER": "True",
|
||||
},
|
||||
)
|
||||
|
||||
# Extract the actual HTTPStatusError
|
||||
http_errors = [
|
||||
e for e in exc_group.value.exceptions if isinstance(e, httpx.HTTPStatusError)
|
||||
]
|
||||
assert http_errors[0].response.status_code == 400
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
anthropic
|
||||
google-genai
|
||||
httpx
|
||||
litellm
|
||||
mcp
|
||||
openai
|
||||
|
||||
@@ -69,19 +69,16 @@ class MCPClient:
|
||||
|
||||
async def run(
|
||||
gateway_url: str,
|
||||
mcp_server_base_url: str,
|
||||
project_name: str,
|
||||
push_to_explorer: bool,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
headers: dict[str, str] = None,
|
||||
):
|
||||
"""
|
||||
Run the MCP client with the given parameters.
|
||||
|
||||
Args:
|
||||
gateway_url: URL of the Invariant Gateway
|
||||
mcp_server_base_url: Base URL of the MCP server
|
||||
project_name: Name of the project in Invariant Explorer
|
||||
push_to_explorer: Whether to push traces to the Invariant Explorer
|
||||
tool_name: Name of the tool to call
|
||||
tool_args: Arguments for the tool call
|
||||
@@ -90,12 +87,7 @@ async def run(
|
||||
client = MCPClient()
|
||||
try:
|
||||
await client.connect_to_sse_server(
|
||||
server_url=gateway_url,
|
||||
headers={
|
||||
"MCP-SERVER-BASE-URL": mcp_server_base_url,
|
||||
"INVARIANT-PROJECT-NAME": project_name,
|
||||
"PUSH-INVARIANT-EXPLORER": str(push_to_explorer),
|
||||
},
|
||||
server_url=gateway_url, headers=headers or {}
|
||||
)
|
||||
# list tools
|
||||
listed_tools = await client.session.list_tools()
|
||||
|
||||
Reference in New Issue
Block a user