mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-02 11:21:40 +02:00
Merge branch 'main' into mcp-metadata
This commit is contained in:
@@ -24,6 +24,7 @@ This allows you to _observe and debug_ your agents in [Invariant Explorer](https
|
||||
- [x] **Single Line Setup**: Just change the base URL of your LLM provider to the Invariant Gateway.
|
||||
- [x] **Intercepts agents at the LLM level** for better debugging and analysis.
|
||||
- [x] **Tool Calling and Computer Use Support** to capture all forms of agentic interactions.
|
||||
- [x] **MCP Protocol Support** for both standard I/O and Server-Sent Events (SSE) transports.
|
||||
- [x] **Seamless forwarding and LLM streaming** to OpenAI, Anthropic, and other LLM providers.
|
||||
- [x] **Store and organize runtime traces** in the [Invariant Explorer](https://explorer.invariantlabs.ai/).
|
||||
|
||||
@@ -277,6 +278,30 @@ export ANTHROPIC_API_KEY={your-anthropic-api-key};invariant-auth={your-invariant
|
||||
|
||||
This setup ensures that SWE-agent works seamlessly with Invariant Gateway, maintaining compatibility while enabling full functionality. 🚀
|
||||
|
||||
### **Using MCP with Invariant Gateway**
|
||||
Invariant Gateway supports MCP tool calling over both stdio and SSE transports.
|
||||
|
||||
For stdio-transport-based MCP, follow the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main/gateway/mcp).
|
||||
|
||||
For SSE-transport-based MCP, follow these steps to point your MCP client to a local instance of the Invariant Gateway, which then proxies all calls to the MCP server:
|
||||
|
||||
* Run the Gateway locally by following the steps [here](https://github.com/invariantlabs-ai/invariant-gateway/tree/main?tab=readme-ov-file#run-the-gateway-locally).
|
||||
* Use the following configuration to connect to the local Gateway instance:
|
||||
```python
|
||||
await client.connect_to_sse_server(
|
||||
server_url="http://localhost:8005/api/v1/gateway/mcp/sse",
|
||||
headers={
|
||||
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
|
||||
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
|
||||
"PUSH-INVARIANT-EXPLORER": "true",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.
|
||||
|
||||
You can also configure blocking or logging guardrails for that project in the Explorer.
|
||||
|
||||
---
|
||||
|
||||
## **Run the Gateway Locally**
|
||||
|
||||
@@ -13,3 +13,23 @@ IGNORED_HEADERS = [
|
||||
]
|
||||
|
||||
# Timeout passed to the outbound HTTP clients created by the gateway.
CLIENT_TIMEOUT = 60.0

# MCP related constants: JSON-RPC keys first, then method names.
MCP_METHOD = "method"
MCP_PARAMS = "params"
MCP_RESULT = "result"
MCP_SERVER_INFO = "serverInfo"
MCP_CLIENT_INFO = "clientInfo"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"

# Message templates used when guardrails block an MCP operation.
# Each template has a single `%s` placeholder for the guardrail details.
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = (
    "\n"
    "[Invariant Guardrails] The MCP tool call was blocked for security reasons.\n"
    "Do not attempt to circumvent this block, rather explain to the user based\n"
    "on the following output what went wrong: %s\n"
)
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = (
    "\n"
    "[Invariant Guardrails] This server was blocked from advertising its tools due to a security guardrail failure.\n"
    "The operation was blocked by Invariant Guardrails (mention this in your user report).\n"
    "When users ask about this tool, inform them that it was blocked due to a security guardrail failure.\n"
    "%s\n"
)
|
||||
@@ -0,0 +1,324 @@
|
||||
"""MCP Sessions Manager related classes"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import (
|
||||
create_annotations_from_guardrails_errors,
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
|
||||
# Explorer API URL used when the INVARIANT_API_URL environment variable is not set.
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def _invariant_authorization() -> str:
    """
    Build the Authorization value from the INVARIANT_API_KEY environment variable.

    Raises:
        ValueError: If INVARIANT_API_KEY is not set. Without this check the
            string concatenation fails with an opaque TypeError.
    """
    api_key = os.getenv("INVARIANT_API_KEY")
    if api_key is None:
        raise ValueError("The INVARIANT_API_KEY environment variable is not set.")
    return "Bearer " + api_key


class McpSession(BaseModel):
    """
    Represents a single MCP session.

    A session keeps the messages seen so far, metadata collected about the
    MCP client and server, the guardrails fetched for its Explorer dataset
    and the bookkeeping needed to push the trace to the Invariant Explorer.
    """

    session_id: str
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Maps a JSON-RPC request id to the method of that request.
    # JSON-RPC ids may be numbers or strings, hence the `Any` key type.
    id_to_method_mapping: Dict[Any, str] = Field(default_factory=dict)
    explorer_dataset: str
    push_explorer: bool
    trace_id: Optional[str] = None
    last_trace_length: int = 0
    # Annotation objects already pushed; they are compared by attribute
    # (`content`, `address`, `extra_metadata`) in `_deduplicate_annotations`.
    annotations: List[Any] = Field(default_factory=list)
    guardrails: GuardrailRuleSet = Field(
        default_factory=lambda: GuardrailRuleSet(
            blocking_guardrails=[], logging_guardrails=[]
        )
    )
    # When tool calls are blocked, the error message is stored here
    # and sent to the client via the SSE stream.
    pending_error_messages: List[dict] = Field(default_factory=list)

    # Lock to maintain in-order pushes to explorer
    _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)

    async def load_guardrails(self) -> None:
        """
        Load guardrails for the session.

        This method fetches guardrails from the Invariant Explorer and assigns them to the session.

        Raises:
            ValueError: If the INVARIANT_API_KEY environment variable is not set.
        """
        self.guardrails = await fetch_guardrails_from_explorer(
            self.explorer_dataset,
            _invariant_authorization(),
        )

    def _deduplicate_annotations(self, new_annotations: list) -> list:
        """Deduplicate new_annotations using the annotations in the session."""

        # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
        # to remove duplicates instead of using a custom logic.
        # This is a temporary solution until the Invariant SDK is updated.
        def _already_recorded(candidate) -> bool:
            return any(
                candidate.content == known.content
                and candidate.address == known.address
                and candidate.extra_metadata == known.extra_metadata
                for known in self.annotations
            )

        return [
            annotation
            for annotation in new_annotations
            if not _already_recorded(annotation)
        ]

    @contextlib.asynccontextmanager
    async def session_lock(self):
        """
        Context manager for the session lock.

        Usage:
        async with session.session_lock():
            # Code that requires exclusive access to the session
        """
        async with self._lock:
            yield

    async def get_guardrails_check_result(
        self,
        message: dict,
        action: GuardrailAction = GuardrailAction.BLOCK,
    ) -> dict:
        """
        Check against guardrails of type action.

        Args:
            message: The message to check; it is appended to the session
                messages for the check only and is not stored.
            action: Which guardrails to evaluate (blocking or logging).

        Returns:
            dict: The result of `check_guardrails`, or an empty dict when no
            guardrails are configured for `action`.
        """
        if action == GuardrailAction.BLOCK:
            guardrails_to_check = self.guardrails.blocking_guardrails
        elif action == GuardrailAction.LOG:
            guardrails_to_check = self.guardrails.logging_guardrails
        else:
            guardrails_to_check = None

        # Skip if no guardrails are configured for this action
        if not guardrails_to_check:
            return {}

        context = RequestContext.create(
            request_json={},
            dataset_name=self.explorer_dataset,
            invariant_authorization=_invariant_authorization(),
            guardrails=self.guardrails,
        )
        return await check_guardrails(
            messages=self.messages + [message],
            guardrails=guardrails_to_check,
            context=context,
        )

    async def add_message(
        self,
        message: Dict[str, Any],
        guardrails_result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a message to the session and optionally push to explorer.

        Args:
            message: The message to add
            guardrails_result: Result of a previous guardrails check for this
                message, if any. Its "errors" are turned into annotations.
        """
        async with self.session_lock():
            annotations = []
            if guardrails_result and guardrails_result.get("errors", []):
                annotations = create_annotations_from_guardrails_errors(
                    guardrails_result.get("errors")
                )

            if self.guardrails.logging_guardrails:
                logging_result = await self.get_guardrails_check_result(
                    message, action=GuardrailAction.LOG
                )
                if logging_result and logging_result.get("errors", []):
                    annotations.extend(
                        create_annotations_from_guardrails_errors(
                            logging_result["errors"]
                        )
                    )

            deduplicated_annotations = self._deduplicate_annotations(annotations)
            # pylint: disable=no-member
            self.messages.append(message)
            # If push_explorer is enabled, push the trace
            if self.push_explorer:
                await self._push_trace_update(deduplicated_annotations)

    async def _push_trace_update(self, deduplicated_annotations: list) -> None:
        """
        Push trace updates to the explorer.

        If a trace doesn't exist, create a new one. If it does, append new messages.
        Errors are printed and not re-raised; in that case the session
        bookkeeping (`annotations`, `last_trace_length`) is left untouched.

        This is an internal method that should only be called within a lock.
        """
        try:
            client = AsyncClient(
                api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
            )

            if not self.trace_id:
                # No trace exists yet: create one with all messages so far.
                # pylint: disable=no-member
                metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])}
                for target_key, source_key in (
                    ("mcp_client", "mcp_client_name"),
                    ("mcp_server", "mcp_server_name"),
                ):
                    if self.metadata.get(source_key):
                        metadata[target_key] = self.metadata.get(source_key)

                response = await client.push_trace(
                    PushTracesRequest(
                        messages=[self.messages],
                        dataset=self.explorer_dataset,
                        metadata=[metadata],
                        annotations=[deduplicated_annotations],
                    )
                )
                self.trace_id = response.id[0]
            else:
                # Only send the messages added since the last successful push.
                new_messages = self.messages[self.last_trace_length :]
                if new_messages:
                    await client.append_messages(
                        AppendMessagesRequest(
                            trace_id=self.trace_id,
                            messages=new_messages,
                            annotations=deduplicated_annotations,
                        )
                    )
            # pylint: disable=no-member
            self.annotations.extend(deduplicated_annotations)
            self.last_trace_length = len(self.messages)
        except Exception as e:  # pylint: disable=broad-except
            print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")

    async def add_pending_error_message(self, error_message: dict) -> None:
        """
        Add a pending error message to the session.

        Args:
            error_message: The error message to add
        """
        async with self.session_lock():
            # pylint: disable=no-member
            self.pending_error_messages.append(error_message)

    async def get_pending_error_messages(self) -> List[dict]:
        """
        Get all pending error messages for the session.

        The pending list is emptied, so each message is returned only once.

        Returns:
            List[dict]: A list of pending error messages
        """
        async with self.session_lock():
            messages = list(self.pending_error_messages)
            self.pending_error_messages = []
            return messages
|
||||
|
||||
|
||||
class SseHeaderAttributes(BaseModel):
    """
    Pydantic model holding the settings read from the SSE request headers.
    """

    push_explorer: bool
    explorer_dataset: str

    @classmethod
    def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
        """
        Build an instance from the headers of an incoming request.

        Args:
            headers: Request headers. `INVARIANT-PROJECT-NAME` selects the
                dataset and `PUSH-INVARIANT-EXPLORER` enables pushing when
                its value is "true" (compared case-insensitively).

        Returns:
            SseHeaderAttributes: An instance with values extracted from headers
        """
        # Without a project name, fall back to a generated capture dataset name.
        dataset = (
            headers.get("INVARIANT-PROJECT-NAME")
            or f"mcp-capture-{random.randint(1, 100)}"
        )
        wants_push = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower() == "true"
        return cls(push_explorer=wants_push, explorer_dataset=dataset)
|
||||
|
||||
|
||||
class McpSessionsManager:
    """
    A class to manage MCP sessions and their messages.

    Sessions are kept in an in-memory dict keyed by session ID.
    """

    def __init__(self):
        self._sessions: dict[str, "McpSession"] = {}

    def session_exists(self, session_id: str) -> bool:
        """Check if a session exists"""
        return session_id in self._sessions

    async def initialize_session(
        self, session_id: str, sse_header_attributes: "SseHeaderAttributes"
    ) -> None:
        """
        Initialize a new session; does nothing if the session already exists.

        The session is registered only after its guardrails were loaded, so a
        failed load never leaves a session without guardrails in the store.

        Args:
            session_id: The session ID
            sse_header_attributes: Dataset and push settings for the session
        """
        if session_id in self._sessions:
            return
        session = McpSession(
            session_id=session_id,
            explorer_dataset=sse_header_attributes.explorer_dataset,
            push_explorer=sse_header_attributes.push_explorer,
        )
        # Load guardrails for the session from the explorer
        await session.load_guardrails()
        # setdefault keeps the first session if the same ID was registered
        # by another task while the guardrails were being loaded.
        self._sessions.setdefault(session_id, session)

    def get_session(self, session_id: str) -> "McpSession":
        """
        Get a session by ID

        Raises:
            ValueError: If no session with this ID exists.
        """
        try:
            return self._sessions[session_id]
        except KeyError:
            raise ValueError(f"Session {session_id} does not exist.") from None

    async def add_message_to_session(
        self, session_id: str, message: Dict[str, Any], guardrails_result: dict
    ) -> None:
        """
        Add a message to a session and push to explorer if enabled.

        Args:
            session_id: The session ID
            message: The message to add
            guardrails_result: The result of the guardrails check

        Raises:
            ValueError: If no session with this ID exists.
        """
        session = self.get_session(session_id)
        await session.add_message(message, guardrails_result)
|
||||
+7
-20
@@ -11,6 +11,13 @@ from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
|
||||
from gateway.common.constants import (
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
@@ -18,30 +25,10 @@ from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
|
||||
import getpass
|
||||
import socket
|
||||
|
||||
MCP_METHOD = "method"
|
||||
UTF_8_ENCODING = "utf-8"
|
||||
MCP_TOOL_CALL = "tools/call"
|
||||
MCP_LIST_TOOLS = "tools/list"
|
||||
MCP_INITIALIZE = "initialize"
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """
|
||||
[Security Failure] The MCP tool call was blocked for security reasons.
|
||||
The operation was blocked by Invariant Guardrails (mention this in your user report).
|
||||
|
||||
Do not attempt to circumvent this block, rather explain to the user based
|
||||
on the following output what went wrong: %s
|
||||
""".strip()
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE = """
|
||||
[Security Failure] This server was blocked from advertising its tools due to a security guardrail failure.
|
||||
|
||||
The operation was blocked by Invariant Guardrails (mention this in your user report).
|
||||
|
||||
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
|
||||
%s
|
||||
""".strip()
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,527 @@
|
||||
"""Gateway service to forward requests to the MCP SSE servers"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
from httpx_sse import aconnect_sse, ServerSentEvent
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
MCP_PARAMS,
|
||||
MCP_RESULT,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_CLIENT_INFO,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
|
||||
# Lower-case names of the client request headers that are forwarded to the
# MCP server; all other request headers are dropped by the handlers.
# Allow-list for POST requests to the messages endpoint.
MCP_SERVER_POST_HEADERS = {"accept", "connection", "content-length", "content-type"}
# Allow-list for the GET request that opens the SSE stream.
MCP_SERVER_SSE_HEADERS = {"accept", "cache-control", "connection"}
|
||||
|
||||
# Router on which the MCP SSE endpoints of this module are registered.
gateway = APIRouter()
# Module-level sessions manager shared by the handlers and hooks below.
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
# Upstream response headers that must not be copied to the proxied response:
# httpx hands us the decoded body, so the original encoding and length
# headers no longer describe the content that is sent to the client.
_UNFORWARDED_RESPONSE_HEADERS = frozenset(
    {"connection", "content-encoding", "content-length", "transfer-encoding"}
)


@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
    request: Request,
) -> Response:
    """
    Proxy calls to the MCP Server tools

    The JSON-RPC request is inspected before it is forwarded: the request id
    is mapped to its method, the MCP client name is recorded, and tool calls
    are checked against the blocking guardrails. A blocked tool call is not
    forwarded; its error is queued on the session (it is delivered through
    the SSE stream) and 202 is returned.

    Raises:
        HTTPException: 400 for a missing or unknown session, a missing
            'mcp-server-base-url' header or a body that is not a JSON object;
            500 when forwarding the request fails.
    """
    query_params = dict(request.query_params)
    session_id = query_params.get("session_id")
    # These errors must be raised: a returned HTTPException is treated as a
    # regular return value instead of producing the 400 response.
    if not session_id:
        raise HTTPException(
            status_code=400,
            detail="Missing 'session_id' query parameter",
        )
    if not session_store.session_exists(session_id):
        raise HTTPException(
            status_code=400,
            detail="Session does not exist",
        )
    mcp_server_base_url = request.headers.get("mcp-server-base-url")
    if not mcp_server_base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )

    # The query string (including session_id) is supplied through `params`.
    mcp_server_messages_endpoint = (
        _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/")
        + "/messages/"
    )

    request_body_bytes = await request.body()
    try:
        request_json = json.loads(request_body_bytes)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from e
    if not isinstance(request_json, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    session = session_store.get_session(session_id)
    method = request_json.get(MCP_METHOD)
    request_id = request_json.get("id")
    # `0` is a valid JSON-RPC id, so test for presence and not truthiness.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method

    params = request_json.get(MCP_PARAMS)
    client_info = params.get(MCP_CLIENT_INFO) if isinstance(params, dict) else None
    if isinstance(client_info, dict) and client_info:
        session.metadata["mcp_client_name"] = client_info.get("name", "")

    if method == MCP_TOOL_CALL:
        # Intercept and potentially block the request
        hook_tool_call_result, is_blocked = await _hook_tool_call(
            session_id=session_id, request_json=request_json
        )
        if is_blocked:
            # Add the error message to the session.
            # The error message is sent back to the client using the SSE stream.
            await session.add_pending_error_message(hook_tool_call_result)
            return Response(content="Accepted", status_code=202)

    forwarded_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() in MCP_SERVER_POST_HEADERS
    }
    forwarded_headers.setdefault("content-type", "application/json")

    async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
        try:
            # Forward the original bytes: the request is not modified, and a
            # re-encoded body could disagree with the forwarded content-length.
            response = await client.post(
                url=mcp_server_messages_endpoint,
                headers=forwarded_headers,
                content=request_body_bytes,
                params=query_params,
            )
            proxied_headers = {
                k: v
                for k, v in response.headers.items()
                if k.lower() not in _UNFORWARDED_RESPONSE_HEADERS
            }
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers={
                    "X-Proxied-By": "mcp-gateway",
                    **proxied_headers,
                },
            )

        except httpx.RequestError as e:
            print(f"[MCP POST] Request error: {str(e)}")
            raise HTTPException(status_code=500, detail="Request error") from e
        except Exception as e:
            print(f"[MCP POST] Unexpected error: {str(e)}")
            raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
|
||||
@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
    request: Request,
) -> StreamingResponse:
    """
    Proxy the SSE stream of the MCP server to the client.

    Events from the MCP server are processed (endpoint rewriting, message
    hooks) and merged with the pending error messages of the session. The
    response stream ends once the MCP server stream has ended and both
    queues are drained.

    Raises:
        HTTPException: 400 when the 'mcp-server-base-url' header is missing.
    """
    mcp_server_base_url = request.headers.get("mcp-server-base-url")
    if not mcp_server_base_url:
        raise HTTPException(
            status_code=400,
            detail="Missing 'mcp-server-base-url' header",
        )
    mcp_server_sse_endpoint = (
        _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/") + "/sse"
    )

    query_params = dict(request.query_params)
    filtered_headers = {
        k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
    }
    sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)

    async def event_generator():
        """
        Generate a merged stream of MCP server events and pending error messages.
        The pending error messages are added in the POST messages handler.
        This function runs in a loop, yielding events as they arrive.
        """
        mcp_server_events_queue = asyncio.Queue()
        pending_error_messages_queue = asyncio.Queue()
        tasks = set()
        # Queue readers of the current loop iteration; kept here so they can
        # be cancelled when the generator is closed while waiting on them.
        queue_getters = []
        session_id = None

        try:
            # MCP Server Events Processor
            async def process_mcp_server_events():
                """Connect to MCP server and process its events."""
                nonlocal session_id
                error_poller_started = False

                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(CLIENT_TIMEOUT)
                ) as client:
                    try:
                        async with aconnect_sse(
                            client,
                            "GET",
                            mcp_server_sse_endpoint,
                            headers=filtered_headers,
                            params=query_params,
                        ) as event_source:
                            if event_source.response.status_code != 200:
                                error_content = await event_source.response.aread()
                                raise HTTPException(
                                    status_code=event_source.response.status_code,
                                    detail=error_content,
                                )

                            async for sse in event_source.aiter_sse():
                                if sse.event == "endpoint":
                                    (
                                        event_bytes,
                                        extracted_id,
                                    ) = await _handle_endpoint_event(
                                        sse, sse_header_attributes
                                    )
                                    session_id = extracted_id

                                    # Start the error message poller once,
                                    # as soon as the session ID is known.
                                    if session_id and not error_poller_started:
                                        process_error_messages_task = (
                                            asyncio.create_task(
                                                _check_for_pending_error_messages(
                                                    session_id,
                                                    pending_error_messages_queue,
                                                )
                                            )
                                        )
                                        tasks.add(process_error_messages_task)
                                        process_error_messages_task.add_done_callback(
                                            tasks.discard
                                        )
                                        error_poller_started = True

                                elif sse.event == "message" and session_id:
                                    # Process message event
                                    event_bytes = await _handle_message_event(
                                        session_id, sse
                                    )
                                else:
                                    # Pass through other event types
                                    # pylint: disable=line-too-long
                                    event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
                                        "utf-8"
                                    )

                                # Put the processed event in the queue
                                await mcp_server_events_queue.put(event_bytes)

                    except httpx.StreamClosed as e:
                        print(f"Server stream closed: {e}", flush=True)
                    except Exception as e:
                        print(f"Error processing server events: {e}", flush=True)

            # Start server events processor
            mcp_server_events_task = asyncio.create_task(process_mcp_server_events())
            tasks.add(mcp_server_events_task)
            mcp_server_events_task.add_done_callback(tasks.discard)

            # Main event loop: merge MCP server events and pending error messages
            while True:
                queue_getters = [
                    asyncio.create_task(mcp_server_events_queue.get()),
                    asyncio.create_task(pending_error_messages_queue.get()),
                ]

                # Wait for either queue to have an item, with timeout
                done, pending = await asyncio.wait(
                    queue_getters,
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=0.25,
                )

                for getter in pending:
                    getter.cancel()

                # Timeout occurred and no future completed.
                if not done:
                    # Stop once the server stream is over and nothing is
                    # left to deliver; otherwise this loop would poll forever.
                    if (
                        mcp_server_events_task.done()
                        and mcp_server_events_queue.empty()
                        and pending_error_messages_queue.empty()
                    ):
                        break
                    continue

                for getter in done:
                    try:
                        event = getter.result()
                    except asyncio.CancelledError:
                        # Future was cancelled, continue
                        continue
                    yield event

        finally:
            # Cancelling an already finished task is a no-op.
            for getter in queue_getters:
                getter.cancel()

            # Clean up all tasks
            remaining_tasks = list(tasks)
            for task in remaining_tasks:
                task.cancel()

            # Wait for all tasks to complete
            if remaining_tasks:
                await asyncio.wait(remaining_tasks, timeout=2)

    # Return the streaming response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"X-Proxied-By": "mcp-gateway"},
    )
|
||||
|
||||
|
||||
# Strong references to the in-flight trace push tasks started by
# `_hook_tool_call`. The event loop only keeps weak references to tasks, so
# a task without another reference may be garbage collected before it is done.
_TOOL_CALL_PUSH_TASKS = set()


async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bool]:
    """
    Hook to process the request JSON before sending it to the MCP server.

    The tool call is checked against the blocking guardrails of the session
    and the message is added to the session in a background task.

    Args:
        session_id (str): The session ID associated with the request.
        request_json (dict): The request JSON to be processed.

    Returns:
        Tuple[dict, bool]: `(error_response, True)` when the call is blocked
        by new guardrail errors, else `(request_json, False)`.
    """
    params = request_json.get(MCP_PARAMS)
    if not isinstance(params, dict):
        # Missing or non-object params: build the tool call with empty fields
        # instead of failing with an AttributeError.
        params = {}
    tool_call = {
        "id": f"call_{request_json.get('id')}",
        "type": "function",
        "function": {
            "name": params.get("name"),
            "arguments": params.get("arguments"),
        },
    }
    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails - this blocks until completion
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )
    errors = guardrails_result.get("errors", []) if guardrails_result else []
    # The request is only blocked when the guardrails reported new errors.
    is_blocked = bool(errors) and _check_if_new_errors(session_id, guardrails_result)

    # Push trace to the explorer - don't block on its response
    push_task = asyncio.create_task(
        session_store.add_message_to_session(
            session_id=session_id,
            message=message,
            guardrails_result=guardrails_result,
        )
    )
    _TOOL_CALL_PUSH_TASKS.add(push_task)
    push_task.add_done_callback(_TOOL_CALL_PUSH_TASKS.discard)

    if not is_blocked:
        return request_json, False

    return {
        "jsonrpc": "2.0",
        "id": request_json.get("id"),
        "error": {
            "code": -32600,
            # One-element tuple: format the errors as a single value
            # whatever their container type is.
            "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % (errors,),
        },
    }, True
|
||||
|
||||
|
||||
# Strong references to the in-flight trace push tasks started by
# `_hook_tool_call_response`. The event loop only keeps weak references to
# tasks, so a task without another reference may be garbage collected early.
_TOOL_RESPONSE_PUSH_TASKS = set()


async def _hook_tool_call_response(session_id: str, response_json: dict) -> dict:
    """
    Hook to process the response JSON after receiving it from the MCP server.

    Args:
        session_id (str): The session ID associated with the request.
        response_json (dict): The response JSON to be processed.

    Returns:
        dict: The response JSON is returned if no guardrail is violated
              else an error dict is returned. A response without a result
              object is returned unchanged and is not checked.
    """
    tool_result = response_json.get(MCP_RESULT)
    if not isinstance(tool_result, dict):
        # Nothing to inspect, e.g. a JSON-RPC error response: pass it through.
        return response_json

    message = {
        "role": "tool",
        "tool_call_id": f"call_{response_json.get('id')}",
        "content": tool_result.get("content"),
        "error": tool_result.get("error"),
    }
    session = session_store.get_session(session_id)
    guardrails_result = await session.get_guardrails_check_result(
        message, action=GuardrailAction.BLOCK
    )
    errors = guardrails_result.get("errors", []) if guardrails_result else []

    result = response_json
    if errors and _check_if_new_errors(session_id, guardrails_result):
        # If the request is blocked, return a message indicating the block reason.
        result = {
            "jsonrpc": "2.0",
            "id": response_json.get("id"),
            "error": {
                "code": -32600,
                # One-element tuple: format the errors as a single value
                # whatever their container type is.
                "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % (errors,),
            },
        }

    # Push trace to the explorer - don't block on its response
    push_task = asyncio.create_task(
        session_store.add_message_to_session(session_id, message, guardrails_result)
    )
    _TOOL_RESPONSE_PUSH_TASKS.add(push_task)
    push_task.add_done_callback(_TOOL_RESPONSE_PUSH_TASKS.discard)
    return result
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
"""
|
||||
Convert localhost or 127.0.0.1 in an address to host.docker.internal
|
||||
|
||||
Args:
|
||||
mcp_server_base_url (str): The original server address from the header
|
||||
|
||||
Returns:
|
||||
str: Modified server address with localhost references changed to host.docker.internal
|
||||
"""
|
||||
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
|
||||
# Replace localhost or 127.0.0.1 with host.docker.internal
|
||||
modified_address = re.sub(
|
||||
r"(https?://)(?:localhost|127\.0\.0\.1)(\b|:)",
|
||||
r"\1host.docker.internal\2",
|
||||
mcp_server_base_url,
|
||||
)
|
||||
return modified_address
|
||||
|
||||
return mcp_server_base_url
|
||||
|
||||
|
||||
async def _handle_endpoint_event(
    sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
) -> Tuple[bytes, str]:
    """
    Handle the endpoint event type and modify the data accordingly.
    For endpoint events, we need to rewrite the endpoint to use our gateway.

    Args:
        sse (ServerSentEvent): The original SSE object.
        sse_header_attributes (SseHeaderAttributes): The header attributes from the request.

    Returns:
        bytes: Modified SSE data as bytes.
        str: session_id extracted from the data, or an empty string when the
            data contains no session_id.
    """
    # Default for data without a session_id; previously the name was left
    # unbound in that case and the return statement failed.
    session_id = ""

    # Extract session_id
    match = re.search(r"session_id=([^&\s]+)", sse.data)
    if match:
        session_id = match.group(1)
        # Initialize this session in our store if needed
        if not session_store.session_exists(session_id):
            await session_store.initialize_session(session_id, sse_header_attributes)

    # Rewrite the endpoint to use our gateway
    modified_data = sse.data.replace(
        "/messages/?session_id=",
        "/api/v1/gateway/mcp/sse/messages/?session_id=",
    )
    event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
    return event_bytes, session_id
|
||||
|
||||
|
||||
async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
    """
    Process a ``message`` SSE event and return the bytes to forward.

    The event payload is parsed as JSON. The server name from an initialize
    result and the tool list from a list-tools result are stored in the
    session metadata. Tool-call results are passed through the tool-call
    response hook, and the hook's return value replaces the payload.

    If the payload cannot be parsed or processing fails, the error is printed
    and the bytes built so far (the original event unless already replaced)
    are returned.

    Args:
        session_id (str): The session ID associated with the request.
        sse (ServerSentEvent): The original SSE object.

    Returns:
        bytes: The SSE event encoded as UTF-8.
    """

    def _encode(data: str) -> bytes:
        # Serialise one SSE frame with the incoming event name.
        return f"event: {sse.event}\ndata: {data}\n\n".encode("utf-8")

    # Default: forward the upstream event untouched.
    outgoing = _encode(sse.data)
    session = session_store.get_session(session_id)

    try:
        payload = json.loads(sse.data)

        # Remember the MCP server name when the response carries server info.
        result = payload.get(MCP_RESULT)
        if result and result.get(MCP_SERVER_INFO):
            server_info = result.get(MCP_SERVER_INFO)
            session.metadata["mcp_server_name"] = server_info.get("name", "")

        # Responses are matched to the request method by their id.
        request_method = session.id_to_method_mapping.get(payload.get("id"))

        if request_method == MCP_TOOL_CALL:
            # The hook returns the response unchanged when no guardrail is
            # violated; otherwise it returns the error message to send.
            hooked = await _hook_tool_call_response(
                session_id=session_id,
                response_json=payload,
            )
            outgoing = _encode(json.dumps(hooked))
        elif request_method == MCP_LIST_TOOLS:
            tools = result.get("tools")
            session_store.get_session(session_id).metadata["tools"] = tools
    except json.JSONDecodeError as e:
        print(
            f"[MCP SSE] Error parsing message JSON: {e}",
            flush=True,
        )
    except Exception as e:  # pylint: disable=broad-except
        print(
            f"[MCP SSE] Error processing message: {e}",
            flush=True,
        )

    return outgoing
|
||||
|
||||
|
||||
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
    """
    Report whether the guardrails result contains errors not yet recorded.

    Args:
        session_id (str): The session whose stored annotations are compared.
        guardrails_result (dict): Guardrails output; its "errors" list
            (empty when absent) is converted to annotations.

    Returns:
        bool: True if at least one resulting annotation is not already in the
            session's annotations, False otherwise.
    """
    session = session_store.get_session(session_id)
    errors = guardrails_result.get("errors", [])
    candidates = create_annotations_from_guardrails_errors(errors)
    return any(candidate not in session.annotations for candidate in candidates)
|
||||
|
||||
|
||||
async def _check_for_pending_error_messages(
    session_id: str, pending_error_messages_queue: asyncio.Queue
):
    """
    Poll the session for pending error messages and enqueue them.

    Roughly once per second, the session's pending error messages are
    fetched, encoded as ``message`` SSE events and put on the queue.
    A failure in one polling round is printed and polling continues.
    The loop ends quietly when the task is cancelled.

    Args:
        session_id (str): The session to poll.
        pending_error_messages_queue (asyncio.Queue): Receives the encoded
            SSE event bytes.
    """
    try:
        while True:
            try:
                session = session_store.get_session(session_id)
                pending = await session.get_pending_error_messages()
                for message in pending:
                    serialized = json.dumps(message)
                    await pending_error_messages_queue.put(
                        f"event: message\ndata: {serialized}\n\n".encode("utf-8")
                    )
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error checking for messages: {e}", flush=True)
            # Wait before the next polling round, whether it succeeded or not.
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Task was cancelled, exit gracefully
        return
|
||||
@@ -7,6 +7,7 @@ from starlette_compress import CompressMiddleware
|
||||
from gateway.routes.anthropic import gateway as anthropic_gateway
|
||||
from gateway.routes.gemini import gateway as gemini_gateway
|
||||
from gateway.routes.open_ai import gateway as open_ai_gateway
|
||||
from gateway.routes.mcp_sse import gateway as mcp_sse_gateway
|
||||
|
||||
app = fastapi.app = fastapi.FastAPI(
|
||||
docs_url="/api/v1/gateway/docs",
|
||||
@@ -30,6 +31,8 @@ router.include_router(anthropic_gateway, prefix="/gateway", tags=["anthropic_gat
|
||||
|
||||
router.include_router(gemini_gateway, prefix="/gateway", tags=["gemini_gateway"])
|
||||
|
||||
router.include_router(mcp_sse_gateway, prefix="/gateway", tags=["mcp_sse_gateway"])
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi==0.115.7",
|
||||
"httpx==0.28.1",
|
||||
"httpx-sse==0.4.0",
|
||||
"invariant-sdk>=0.0.11",
|
||||
"starlette-compress==1.4.0",
|
||||
"uvicorn==0.34.0"
|
||||
|
||||
@@ -195,7 +195,6 @@ async def test_response_with_tool_call(explorer_api_url, gateway_url, push_to_ex
|
||||
|
||||
assert response[1].role == "assistant"
|
||||
assert response[1].stop_reason == "end_turn"
|
||||
assert city in response[1].content[0].text.lower()
|
||||
responses.append(response)
|
||||
|
||||
if push_to_explorer:
|
||||
|
||||
@@ -123,7 +123,9 @@ async def test_chat_completion_with_tool_call_without_streaming(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
assert trace["messages"][:2] == expected_messages[:2]
|
||||
assert "15°C" in trace["messages"][2]["content"]
|
||||
assert trace["messages"][2]["role"] == "tool"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@@ -230,4 +232,6 @@ async def test_chat_completion_with_tool_call_with_streaming(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"] = json.loads(
|
||||
expected_messages[1]["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
assert trace["messages"] == expected_messages
|
||||
assert trace["messages"][:2] == expected_messages[:2]
|
||||
assert "15°C" in trace["messages"][2]["content"]
|
||||
assert trace["messages"][2]["role"] == "tool"
|
||||
|
||||
Reference in New Issue
Block a user