From 05e09331e98236056e2fbcc2c9068ccef41fcb68 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 4 Jun 2025 11:20:52 +0200 Subject: [PATCH] Address comments on PR and update README. --- README.md | 30 +++++++++++++++++++++++--- gateway/mcp/constants.py | 3 ++- gateway/mcp/mcp_sessions_manager.py | 2 -- gateway/mcp/sse.py | 26 ++++++++++++----------- gateway/mcp/streamable.py | 33 +++++++++++++++-------------- 5 files changed, 60 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 2ccf0dc..e9cfbac 100644 --- a/README.md +++ b/README.md @@ -279,22 +279,44 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀 ### **Using MCP with Invariant Gateway** -Invariant Gateway supports MCP (both stdio and SSE transports) tool calling. +Invariant Gateway supports MCP (stdio, SSE and Streamable http) tool calling. For stdio transport based MCP, follow steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp). -For SSE transport based MCP, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server: +For **SSE transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing: * Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally). * Use the following configuration to connect to the local Gateway instance: ```python -await client.connect_to_sse_server( +from mcp.client.sse import sse_client + +await connect_to_sse_server( server_url="http://localhost:8005/api/v1/gateway/mcp/sse", headers={ "MCP-SERVER-BASE-URL": "", "INVARIANT-PROJECT-NAME": "", "PUSH-INVARIANT-EXPLORER": "true", "INVARIANT-API-KEY": "" + "INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "" + }, + ) +``` + +For **Streamable HTTP transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing: + +* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally). +* Use the following configuration to connect to the local Gateway instance: +```python +from mcp.client.streamable_http import streamablehttp_client + +await streamablehttp_client( + url="http://localhost:8005/api/v1/gateway/mcp/sse", + headers={ + "MCP-SERVER-BASE-URL": "", + "INVARIANT-PROJECT-NAME": "", + "PUSH-INVARIANT-EXPLORER": "true", + "INVARIANT-API-KEY": "" + "INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "" }, ) ``` @@ -303,6 +325,8 @@ The `INVARIANT-API-KEY` header is used both for pushing the traces to explorer a If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there. +If you pass a header called `INVARIANT-X-MCP-SERVER-CUSTOM-API-KEY`, it will be passed as the `CUSTOM-API-KEY` header to the underlying MCP server. + You can also specify blocking or logging guardrails for the project name by visiting the Explorer. --- diff --git a/gateway/mcp/constants.py b/gateway/mcp/constants.py index a783d7f..9f6c1ad 100644 --- a/gateway/mcp/constants.py +++ b/gateway/mcp/constants.py @@ -20,4 +20,5 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """ %s """ MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url" -UTF_8 = "utf-8" \ No newline at end of file +UTF_8 = "utf-8" +MCP_CUSTOM_HEADER_PREFIX = "INVARIANT-X-MCP-SERVER-" diff --git a/gateway/mcp/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py index 2cf88ae..bb915d2 100644 --- a/gateway/mcp/mcp_sessions_manager.py +++ b/gateway/mcp/mcp_sessions_manager.py @@ -274,7 +274,6 @@ class McpAttributes(BaseModel): push_explorer: bool explorer_dataset: str invariant_api_key: Optional[str] = None - failure_response_format: Optional[str] = None verbose: Optional[bool] = False metadata: dict[str, Any] = Field(default_factory=dict) @@ -360,7 +359,6 @@ class McpAttributes(BaseModel): return cls( push_explorer=config.push_explorer, explorer_dataset=config.project_name, - failure_response_format=config.failure_response_format, verbose=config.verbose, metadata=metadata, ) diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 7a8fa5e..3109175 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Request, Response from fastapi.responses import StreamingResponse from gateway.common.constants import CLIENT_TIMEOUT -from gateway.mcp.constants import UTF_8 +from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8 from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, @@ -123,17 +123,18 @@ class SSETransport(MCPTransportBase): mcp_server_messages_endpoint = f"{mcp_server_base_url}/messages/?{session_id}" # Filter headers for MCP server - mcp_headers = { - k: v - for k, v in request.headers.items() - if k.lower() in {"connection", "accept", "content-length", "content-type"} - } + filtered_headers = {} + for k, v in request.headers.items(): + if k.startswith(MCP_CUSTOM_HEADER_PREFIX): + filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v + if k.lower() in MCP_SERVER_POST_HEADERS: + filtered_headers[k] = v async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: try: response = await client.post( url=mcp_server_messages_endpoint, - headers=mcp_headers, + headers=filtered_headers, json=request_body, params=dict(request.query_params), ) @@ -155,11 +156,12 @@ class SSETransport(MCPTransportBase): response_headers = {} # Filter headers for SSE - filtered_headers = { - k: v - for k, v in request.headers.items() - if k.lower() in {"connection", "accept", "cache-control"} - } + filtered_headers = {} + for k, v in request.headers.items(): + if k.startswith(MCP_CUSTOM_HEADER_PREFIX): + filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v + if k.lower() in MCP_SERVER_SSE_HEADERS: + filtered_headers[k] = v sse_header_attributes = McpAttributes.from_request_headers(request.headers) diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index aa89c70..39a2b01 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse from gateway.common.constants import CLIENT_TIMEOUT from gateway.mcp.constants import ( INVARIANT_SESSION_ID_PREFIX, + MCP_CUSTOM_HEADER_PREFIX, UTF_8, ) from gateway.mcp.mcp_sessions_manager import ( @@ -148,11 +149,12 @@ class StreamableTransport(MCPTransportBase): mcp_server_endpoint = self._get_mcp_server_endpoint(request) response_headers = {} - filtered_headers = { - k: v - for k, v in request.headers.items() - if k.lower() in MCP_SERVER_GET_HEADERS - } + filtered_headers = {} + for k, v in request.headers.items(): + if k.startswith(MCP_CUSTOM_HEADER_PREFIX): + filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v + if k.lower() in MCP_SERVER_GET_HEADERS: + filtered_headers[k] = v async def event_generator(): async with httpx.AsyncClient( @@ -399,17 +401,16 @@ class StreamableTransport(MCPTransportBase): def _get_headers_for_mcp_post_and_delete(self, request: Request) -> dict: """Get filtered headers for MCP server requests.""" - return { - k: v - for k, v in request.headers.items() - if ( - k.lower() in MCP_SERVER_POST_DELETE_HEADERS - and not ( - k.lower() == MCP_SESSION_ID_HEADER - and v.startswith(INVARIANT_SESSION_ID_PREFIX) - ) - ) - } + filtered_headers = {} + for k, v in request.headers.items(): + if k.startswith(MCP_CUSTOM_HEADER_PREFIX): + filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v + if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not ( + k.lower() == MCP_SESSION_ID_HEADER + and v.startswith(INVARIANT_SESSION_ID_PREFIX) + ): + filtered_headers[k] = v + return filtered_headers def _get_session_id(self, request: Request) -> str: """Extract session ID from request headers."""