mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Address comments on PR and update README.
This commit is contained in:
30
README.md
30
README.md
@@ -279,22 +279,44 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
|
||||
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
|
||||
|
||||
### **Using MCP with Invariant Gateway**
|
||||
Invariant Gateway supports MCP (both stdio and SSE transports) tool calling.
|
||||
Invariant Gateway supports MCP (stdio, SSE and Streamable http) tool calling.
|
||||
|
||||
For stdio transport based MCP, follow steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
|
||||
|
||||
For SSE transport based MCP, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server:
|
||||
For **SSE transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing:
|
||||
|
||||
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
|
||||
* Use the following configuration to connect to the local Gateway instance:
|
||||
```python
|
||||
await client.connect_to_sse_server(
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
await connect_to_sse_server(
|
||||
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
|
||||
headers={
|
||||
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
|
||||
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
|
||||
"PUSH-INVARIANT-EXPLORER": "true",
|
||||
"INVARIANT-API-KEY": "<your-invariant-api-key>"
|
||||
"INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "<custom-value-passed-to-mcp-server>"
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
For **Streamable HTTP transport based MCP**, here are the steps to point your MCP client to a local instance of the Invariant Gateway which will then proxy all calls to the MCP server while guardrailing:
|
||||
|
||||
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
|
||||
* Use the following configuration to connect to the local Gateway instance:
|
||||
```python
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
await streamablehttp_client(
|
||||
url="http://localhost:8005/api/v1/gateway/mcp/sse",
|
||||
headers={
|
||||
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
|
||||
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
|
||||
"PUSH-INVARIANT-EXPLORER": "true",
|
||||
"INVARIANT-API-KEY": "<your-invariant-api-key>"
|
||||
"INVARIANT-X-MCP-SERVER-{CUSTOM-MCP-SERVER-HEADER-NAME}": "<custom-value-passed-to-mcp-server>"
|
||||
},
|
||||
)
|
||||
```
|
||||
@@ -303,6 +325,8 @@ The `INVARIANT-API-KEY` header is used both for pushing the traces to explorer a
|
||||
|
||||
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
|
||||
|
||||
If you pass a header called `INVARIANT-X-MCP-SERVER-CUSTOM-API-KEY`, it will be passed as the `CUSTOM-API-KEY` header to the underlying MCP server.
|
||||
|
||||
You can also specify blocking or logging guardrails for the project name by visiting the Explorer.
|
||||
|
||||
---
|
||||
|
||||
@@ -20,4 +20,5 @@ INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
|
||||
%s
|
||||
"""
|
||||
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
|
||||
UTF_8 = "utf-8"
|
||||
UTF_8 = "utf-8"
|
||||
MCP_CUSTOM_HEADER_PREFIX = "INVARIANT-X-MCP-SERVER-"
|
||||
|
||||
@@ -274,7 +274,6 @@ class McpAttributes(BaseModel):
|
||||
push_explorer: bool
|
||||
explorer_dataset: str
|
||||
invariant_api_key: Optional[str] = None
|
||||
failure_response_format: Optional[str] = None
|
||||
verbose: Optional[bool] = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -360,7 +359,6 @@ class McpAttributes(BaseModel):
|
||||
return cls(
|
||||
push_explorer=config.push_explorer,
|
||||
explorer_dataset=config.project_name,
|
||||
failure_response_format=config.failure_response_format,
|
||||
verbose=config.verbose,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import CLIENT_TIMEOUT
|
||||
from gateway.mcp.constants import UTF_8
|
||||
from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
@@ -123,17 +123,18 @@ class SSETransport(MCPTransportBase):
|
||||
mcp_server_messages_endpoint = f"{mcp_server_base_url}/messages/?{session_id}"
|
||||
|
||||
# Filter headers for MCP server
|
||||
mcp_headers = {
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() in {"connection", "accept", "content-length", "content-type"}
|
||||
}
|
||||
filtered_headers = {}
|
||||
for k, v in request.headers.items():
|
||||
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
|
||||
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
|
||||
if k.lower() in MCP_SERVER_POST_HEADERS:
|
||||
filtered_headers[k] = v
|
||||
|
||||
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
url=mcp_server_messages_endpoint,
|
||||
headers=mcp_headers,
|
||||
headers=filtered_headers,
|
||||
json=request_body,
|
||||
params=dict(request.query_params),
|
||||
)
|
||||
@@ -155,11 +156,12 @@ class SSETransport(MCPTransportBase):
|
||||
response_headers = {}
|
||||
|
||||
# Filter headers for SSE
|
||||
filtered_headers = {
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() in {"connection", "accept", "cache-control"}
|
||||
}
|
||||
filtered_headers = {}
|
||||
for k, v in request.headers.items():
|
||||
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
|
||||
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
|
||||
if k.lower() in MCP_SERVER_SSE_HEADERS:
|
||||
filtered_headers[k] = v
|
||||
|
||||
sse_header_attributes = McpAttributes.from_request_headers(request.headers)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse
|
||||
from gateway.common.constants import CLIENT_TIMEOUT
|
||||
from gateway.mcp.constants import (
|
||||
INVARIANT_SESSION_ID_PREFIX,
|
||||
MCP_CUSTOM_HEADER_PREFIX,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
@@ -148,11 +149,12 @@ class StreamableTransport(MCPTransportBase):
|
||||
mcp_server_endpoint = self._get_mcp_server_endpoint(request)
|
||||
response_headers = {}
|
||||
|
||||
filtered_headers = {
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() in MCP_SERVER_GET_HEADERS
|
||||
}
|
||||
filtered_headers = {}
|
||||
for k, v in request.headers.items():
|
||||
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
|
||||
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
|
||||
if k.lower() in MCP_SERVER_GET_HEADERS:
|
||||
filtered_headers[k] = v
|
||||
|
||||
async def event_generator():
|
||||
async with httpx.AsyncClient(
|
||||
@@ -399,17 +401,16 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
def _get_headers_for_mcp_post_and_delete(self, request: Request) -> dict:
|
||||
"""Get filtered headers for MCP server requests."""
|
||||
return {
|
||||
k: v
|
||||
for k, v in request.headers.items()
|
||||
if (
|
||||
k.lower() in MCP_SERVER_POST_DELETE_HEADERS
|
||||
and not (
|
||||
k.lower() == MCP_SESSION_ID_HEADER
|
||||
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
|
||||
)
|
||||
)
|
||||
}
|
||||
filtered_headers = {}
|
||||
for k, v in request.headers.items():
|
||||
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
|
||||
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
|
||||
if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not (
|
||||
k.lower() == MCP_SESSION_ID_HEADER
|
||||
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
|
||||
):
|
||||
filtered_headers[k] = v
|
||||
return filtered_headers
|
||||
|
||||
def _get_session_id(self, request: Request) -> str:
|
||||
"""Extract session ID from request headers."""
|
||||
|
||||
Reference in New Issue
Block a user