Address comments on PR and update README.

This commit is contained in:
Hemang
2025-06-04 11:20:52 +02:00
committed by Hemang Sarkar
parent cc3e96c20a
commit 05e09331e9
5 changed files with 60 additions and 34 deletions

View File

@@ -279,22 +279,44 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
### **Using MCP with Invariant Gateway**
Invariant Gateway supports MCP (both stdio and SSE transports) tool calling.
Invariant Gateway supports MCP (stdio, SSE and Streamable http) tool calling.
For stdio transport based MCP, follow steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
For SSE transport based MCP, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server:
For **SSE transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing:
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
* Use the following configuration to connect to the local Gateway instance:
```python
await client.connect_to_sse_server(
from mcp.client.sse import sse_client
await connect_to_sse_server(
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
headers={
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
"PUSH-INVARIANT-EXPLORER": "true",
"INVARIANT-API-KEY": "<your-invariant-api-key>"
"INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "<custom-value-passed-to-mcp-server>"
},
)
```
For **Streamable HTTP transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing:
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
* Use the following configuration to connect to the local Gateway instance:
```python
from mcp.client.streamable_http import streamablehttp_client
await streamablehttp_client(
url="http://localhost:8005/api/v1/gateway/mcp/sse",
headers={
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
"PUSH-INVARIANT-EXPLORER": "true",
"INVARIANT-API-KEY": "<your-invariant-api-key>"
"INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "<custom-value-passed-to-mcp-server>"
},
)
```
@@ -303,6 +325,8 @@ The `INVARIANT-API-KEY` header is used both for pushing the traces to explorer a
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
If you pass a header called `INVARIANT-X-MCP-SERVER-CUSTOM-API-KEY`, it will be passed as the `CUSTOM-API-KEY` header to the underlying MCP server.
You can also specify blocking or logging guardrails for the project name by visiting the Explorer.
---

View File

@@ -20,4 +20,5 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
%s
"""
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
UTF_8 = "utf-8"
UTF_8 = "utf-8"
MCP_CUSTOM_HEADER_PREFIX = "INVARIANT-X-MCP-SERVER-"

View File

@@ -274,7 +274,6 @@ class McpAttributes(BaseModel):
push_explorer: bool
explorer_dataset: str
invariant_api_key: Optional[str] = None
failure_response_format: Optional[str] = None
verbose: Optional[bool] = False
metadata: dict[str, Any] = Field(default_factory=dict)
@@ -360,7 +359,6 @@ class McpAttributes(BaseModel):
return cls(
push_explorer=config.push_explorer,
explorer_dataset=config.project_name,
failure_response_format=config.failure_response_format,
verbose=config.verbose,
metadata=metadata,
)

View File

@@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import CLIENT_TIMEOUT
from gateway.mcp.constants import UTF_8
from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8
from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
@@ -123,17 +123,18 @@ class SSETransport(MCPTransportBase):
mcp_server_messages_endpoint = f"{mcp_server_base_url}/messages/?{session_id}"
# Filter headers for MCP server
mcp_headers = {
k: v
for k, v in request.headers.items()
if k.lower() in {"connection", "accept", "content-length", "content-type"}
}
filtered_headers = {}
for k, v in request.headers.items():
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
if k.lower() in MCP_SERVER_POST_HEADERS:
filtered_headers[k] = v
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
response = await client.post(
url=mcp_server_messages_endpoint,
headers=mcp_headers,
headers=filtered_headers,
json=request_body,
params=dict(request.query_params),
)
@@ -155,11 +156,12 @@ class SSETransport(MCPTransportBase):
response_headers = {}
# Filter headers for SSE
filtered_headers = {
k: v
for k, v in request.headers.items()
if k.lower() in {"connection", "accept", "cache-control"}
}
filtered_headers = {}
for k, v in request.headers.items():
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
if k.lower() in MCP_SERVER_SSE_HEADERS:
filtered_headers[k] = v
sse_header_attributes = McpAttributes.from_request_headers(request.headers)

View File

@@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse
from gateway.common.constants import CLIENT_TIMEOUT
from gateway.mcp.constants import (
INVARIANT_SESSION_ID_PREFIX,
MCP_CUSTOM_HEADER_PREFIX,
UTF_8,
)
from gateway.mcp.mcp_sessions_manager import (
@@ -148,11 +149,12 @@ class StreamableTransport(MCPTransportBase):
mcp_server_endpoint = self._get_mcp_server_endpoint(request)
response_headers = {}
filtered_headers = {
k: v
for k, v in request.headers.items()
if k.lower() in MCP_SERVER_GET_HEADERS
}
filtered_headers = {}
for k, v in request.headers.items():
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
if k.lower() in MCP_SERVER_GET_HEADERS:
filtered_headers[k] = v
async def event_generator():
async with httpx.AsyncClient(
@@ -399,17 +401,16 @@ class StreamableTransport(MCPTransportBase):
def _get_headers_for_mcp_post_and_delete(self, request: Request) -> dict:
"""Get filtered headers for MCP server requests."""
return {
k: v
for k, v in request.headers.items()
if (
k.lower() in MCP_SERVER_POST_DELETE_HEADERS
and not (
k.lower() == MCP_SESSION_ID_HEADER
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
)
)
}
filtered_headers = {}
for k, v in request.headers.items():
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not (
k.lower() == MCP_SESSION_ID_HEADER
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
):
filtered_headers[k] = v
return filtered_headers
def _get_session_id(self, request: Request) -> str:
"""Extract session ID from request headers."""