Refactor stdio implementation to use McpSession class.

This commit is contained in:
Hemang
2025-06-03 11:55:36 +02:00
committed by Hemang Sarkar
parent 6849fc7daa
commit e8106776b4
10 changed files with 290 additions and 760 deletions

View File

@@ -1,5 +1,7 @@
"""Common constants used in the gateway."""
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
IGNORED_HEADERS = [
"accept-encoding",
"host",

View File

@@ -16,7 +16,7 @@ from invariant_sdk.types.push_traces import PushTracesRequest
from pydantic import BaseModel, Field, PrivateAttr
from starlette.datastructures import Headers
from gateway.common.constants import INVARIANT_SESSION_ID_PREFIX
from gateway.common.constants import DEFAULT_API_URL, INVARIANT_SESSION_ID_PREFIX
from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import (
@@ -25,8 +25,6 @@ from gateway.integrations.explorer import (
)
from gateway.integrations.guardrails import check_guardrails
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def user_and_host() -> str:
"""Get the current user and hostname."""
@@ -61,30 +59,15 @@ class McpSession(BaseModel):
# and other session-related operations
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def get_invariant_api_key(self) -> str:
def _get_invariant_api_key(self) -> str:
"""Get the Invariant API key for the session."""
if self.attributes.invariant_api_key:
return self.attributes.invariant_api_key
return os.getenv("INVARIANT_API_KEY")
def get_invariant_authorization(self) -> str:
def _get_invariant_authorization(self) -> str:
"""Get the Invariant authorization header for the session."""
return "Bearer " + self.get_invariant_api_key()
async def load_guardrails(self) -> None:
"""
Load guardrails for the session.
This method fetches guardrails from the Invariant Explorer and assigns them to the session.
"""
print("Inside load_guardrails attributes: ", self.attributes, flush=True)
self.guardrails = await fetch_guardrails_from_explorer(
self.attributes.explorer_dataset,
self.get_invariant_authorization(),
# pylint: disable=no-member
self.attributes.metadata.get("mcp_client"),
self.attributes.metadata.get("mcp_server"),
)
return "Bearer " + self._get_invariant_api_key()
def _deduplicate_annotations(self, new_annotations: list) -> list:
"""Deduplicate new_annotations using the annotations in the session."""
@@ -94,6 +77,20 @@ class McpSession(BaseModel):
deduped_annotations.append(annotation)
return deduped_annotations
async def load_guardrails(self) -> None:
"""
Load guardrails for the session.
This method fetches guardrails from the Invariant Explorer and assigns them to the session.
"""
self.guardrails = await fetch_guardrails_from_explorer(
self.attributes.explorer_dataset,
self._get_invariant_authorization(),
# pylint: disable=no-member
self.attributes.metadata.get("mcp_client"),
self.attributes.metadata.get("mcp_server"),
)
@contextlib.asynccontextmanager
async def session_lock(self):
"""
@@ -134,15 +131,10 @@ class McpSession(BaseModel):
return {}
# Prepare context and select appropriate guardrails
print(
"Inside get_guardrails_check_result attributes: ",
self.attributes,
flush=True,
)
context = RequestContext.create(
request_json={},
dataset_name=self.attributes.explorer_dataset,
invariant_authorization=self.get_invariant_authorization(),
invariant_authorization=self._get_invariant_authorization(),
guardrails=self.guardrails,
guardrails_parameters={
"metadata": self.session_metadata(),
@@ -210,11 +202,10 @@ class McpSession(BaseModel):
This is an internal method that should only be called within a lock.
"""
print("Inside _push_trace_update attributes: ", self.attributes, flush=True)
try:
client = AsyncClient(
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
api_key=self.get_invariant_api_key(),
api_key=self._get_invariant_api_key(),
)
# If no trace exists, create a new one
@@ -247,7 +238,7 @@ class McpSession(BaseModel):
self.annotations.extend(deduplicated_annotations)
self.last_trace_length = len(self.messages)
except Exception as e: # pylint: disable=broad-except
print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
print(f"[MCP] Error pushing trace for session {self.session_id}: {e}")
async def add_pending_error_message(self, error_message: dict) -> None:
"""

View File

@@ -1,7 +1,9 @@
"""MCP utility functions."""
import asyncio
import json
import re
import uuid
from typing import Tuple
@@ -9,18 +11,83 @@ from fastapi import Request, HTTPException
from gateway.common.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
INVARIANT_SESSION_ID_PREFIX,
MCP_CLIENT_INFO,
MCP_SERVER_BASE_URL_HEADER,
MCP_LIST_TOOLS,
MCP_METHOD,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
McpSession,
McpSessionsManager,
)
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.mcp.log import format_errors_in_response
def _check_if_new_errors(
session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
"""Checks if there are new errors in the guardrails result."""
session = session_store.get_session(session_id)
annotations = create_annotations_from_guardrails_errors(
guardrails_result.get("errors", [])
)
for annotation in annotations:
if annotation not in session.annotations:
return True
return False
def generate_session_id() -> str:
"""
Generate a new session ID.
If the MCP server is session less then we don't have a session ID from the MCP server.
"""
return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex
def update_mcp_server_in_session_metadata(
session: McpSession, response_body: dict
) -> None:
"""Update the MCP server information in the session metadata."""
if response_body.get(MCP_RESULT) and response_body.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.attributes.metadata["mcp_server"] = (
response_body.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
)
def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None:
"""Updates the tool call ID in the session."""
if request_body.get(MCP_METHOD) and request_body.get("id"):
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
MCP_METHOD
)
def update_mcp_client_info_in_session(session: McpSession, request_body: dict) -> None:
"""Update the MCP client info in the session metadata."""
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.attributes.metadata["mcp_client"] = (
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
def update_session_from_request(session: McpSession, request_body: dict) -> None:
"""Update the MCP client information and request id in the session."""
update_mcp_client_info_in_session(session, request_body)
update_tool_call_id_in_session(session, request_body)
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
"""
Convert localhost or 127.0.0.1 in an address to host.docker.internal
@@ -73,7 +140,13 @@ async def hook_tool_call(
Args:
session_id (str): The session ID associated with the request.
session_store (McpSessionsManager): The session store to manage sessions.
request_body (dict): The request JSON to be processed.
Returns:
Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean
indicating whether the request was blocked. If the request is blocked, the
dict will contain an error message else it will contain the original request.
"""
tool_call = {
"id": f"call_{request_body.get('id')}",
@@ -89,14 +162,13 @@ async def hook_tool_call(
guardrails_result = await session.get_guardrails_check_result(
message, action=GuardrailAction.BLOCK
)
print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True)
# If the request is blocked, return a message indicating the block reason.
# If there are new errors, run append_and_push_trace in background.
# If there are no new errors, just return the original request.
if (
guardrails_result
and guardrails_result.get("errors", [])
and check_if_new_errors(session_id, session_store, guardrails_result)
and _check_if_new_errors(session_id, session_store, guardrails_result)
):
# Add the trace to the explorer
asyncio.create_task(
@@ -120,24 +192,10 @@ async def hook_tool_call(
return request_body, False
def check_if_new_errors(
session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
"""Checks if there are new errors in the guardrails result."""
session = session_store.get_session(session_id)
annotations = create_annotations_from_guardrails_errors(
guardrails_result.get("errors", [])
)
for annotation in annotations:
if annotation not in session.annotations:
return True
return False
async def hook_tool_call_response(
session_id: str,
session_store: McpSessionsManager,
response_json: dict,
response_body: dict,
is_tools_list=False,
) -> dict:
"""
@@ -145,19 +203,21 @@ async def hook_tool_call_response(
Hook to process the response JSON after receiving it from the MCP server.
Args:
session_id (str): The session ID associated with the request.
response_json (dict): The response JSON to be processed.
session_store (McpSessionsManager): The session store to manage sessions.
response_body (dict): The response JSON to be processed.
is_tools_list (bool): Flag to indicate if the response is from a tools/list call.
Returns:
dict: The response JSON is returned if no guardrail is violated
else an error dict is returned.
"""
blocked = False
is_blocked = False
result = response_body
message = {
"role": "tool",
"tool_call_id": f"call_{response_json.get('id')}",
"content": response_json.get(MCP_RESULT).get("content"),
"error": response_json.get(MCP_RESULT).get("error"),
"tool_call_id": f"call_{result.get('id')}",
"content": result.get(MCP_RESULT).get("content"),
"error": result.get(MCP_RESULT).get("error"),
}
result = response_json
session = session_store.get_session(session_id)
guardrails_result = await session.get_guardrails_check_result(
message, action=GuardrailAction.BLOCK
@@ -166,14 +226,14 @@ async def hook_tool_call_response(
if (
guardrails_result
and guardrails_result.get("errors", [])
and check_if_new_errors(session_id, session_store, guardrails_result)
and _check_if_new_errors(session_id, session_store, guardrails_result)
):
blocked = True
is_blocked = True
# If the request is blocked, return a message indicating the block reason
if not is_tools_list:
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"id": response_body.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
@@ -184,7 +244,7 @@ async def hook_tool_call_response(
# special error response for tools/list tool call
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"id": response_body.get("id"),
"result": {
"tools": [
{
@@ -201,13 +261,64 @@ async def hook_tool_call_response(
"title": "This tool was blocked by security guardrails.",
},
}
for tool in response_json["result"]["tools"]
for tool in response_body["result"]["tools"]
]
},
}
# Push trace to the explorer - don't block on its response
asyncio.create_task(
session_store.add_message_to_session(session_id, message, guardrails_result)
)
return result, blocked
# Push trace to the explorer
await session_store.add_message_to_session(session_id, message, guardrails_result)
return result, is_blocked
async def intercept_response(
session_id: str, session_store: McpSessionsManager, response_body: dict
) -> Tuple[dict, bool]:
"""
Intercept the response and check for guardrails.
This function is used to intercept responses and check for guardrails.
If the response is blocked, it returns a message indicating the block
reason with a boolean flag set to True. If the response is not blocked,
it returns the original response with a boolean flag set to False.
Args:
session_id (str): The session ID associated with the request.
session_store (McpSessionsManager): The session store to manage sessions.
response_body (dict): The response JSON to be processed.
Returns:
Tuple[dict, bool]: A tuple containing the processed response JSON
and a boolean indicating whether the response was blocked.
"""
session = session_store.get_session(session_id)
method = session.id_to_method_mapping.get(response_body.get("id"))
intercept_response_result = response_body
is_blocked = False
# Intercept and potentially block tool call response
if method == MCP_TOOL_CALL:
intercept_response_result, is_blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body=response_body,
)
# Intercept and potentially block list tool call response
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).attributes.metadata["tools"] = (
response_body.get(MCP_RESULT).get("tools")
)
intercept_response_result, is_blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body={
"jsonrpc": "2.0",
"id": response_body.get("id"),
"result": {
"content": json.dumps(response_body.get(MCP_RESULT).get("tools")),
"tools": response_body.get(MCP_RESULT).get("tools"),
},
},
is_tools_list=True,
)
return intercept_response_result, is_blocked

View File

@@ -6,6 +6,7 @@ import json
from typing import Any, Dict, List
from fastapi import HTTPException
from gateway.common.constants import DEFAULT_API_URL
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
@@ -13,8 +14,6 @@ from invariant_sdk.types.annotations import AnnotationCreate
import httpx
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def create_annotations_from_guardrails_errors(
guardrails_errors: List[dict],

View File

@@ -9,14 +9,13 @@ from functools import wraps
from fastapi import HTTPException
import httpx
from gateway.common.constants import DEFAULT_API_URL
from gateway.common.guardrails import Guardrail
from gateway.common.request_context import RequestContext
from gateway.common.authorization import (
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
)
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
# Timestamps of last API calls per guardrails string
_guardrails_cache = {}

View File

@@ -1,63 +1,36 @@
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
import asyncio
import getpass
import json
import os
import platform
import select
import socket
import subprocess
import sys
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
from invariant_sdk.types.push_traces import PushTracesRequest
from gateway.common.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
MCP_METHOD,
MCP_CLIENT_INFO,
MCP_PARAMS,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_sync
from gateway.common.mcp_sessions_manager import (
McpAttributes,
McpSessionsManager,
)
from gateway.common.mcp_utils import (
generate_session_id,
hook_tool_call,
intercept_response,
update_mcp_server_in_session_metadata,
update_session_from_request,
)
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
STATUS_EOF = "eof"
STATUS_DATA = "data"
STATUS_WAIT = "wait"
def user_and_host() -> str:
"""Get the current user and hostname."""
username = getpass.getuser()
hostname = socket.gethostname()
return f"{username}@{hostname}"
def session_metadata(ctx: McpContext) -> dict:
"""Generate metadata for the current session."""
return {
"session_id": ctx.local_session_id,
"system_user": user_and_host(),
"mcp_client": ctx.mcp_client_name,
"mcp_server": ctx.mcp_server_name,
"tools": ctx.tools,
**(ctx.extra_metadata or {}),
}
session_store = McpSessionsManager()
def write_as_utf8_bytes(data: dict) -> bytes:
@@ -65,326 +38,37 @@ def write_as_utf8_bytes(data: dict) -> bytes:
return json.dumps(data).encode(UTF_8) + b"\n"
def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list:
"""Deduplicate new_annotations using the annotations in the context."""
deduped_annotations = []
for annotation in new_annotations:
if annotation not in ctx.annotations:
deduped_annotations.append(annotation)
return deduped_annotations
def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
"""Checks if there are new errors in the guardrails result."""
annotations = create_annotations_from_guardrails_errors(
guardrails_result.get("errors", [])
)
for annotation in annotations:
if annotation not in ctx.annotations:
return True
return False
async def append_and_push_trace(
ctx: McpContext, message: dict, guardrails_result: dict
) -> None:
"""
Append a message to the trace if it exists or create a new one
and push it to the Invariant Explorer.
"""
annotations = []
if guardrails_result and guardrails_result.get("errors", []):
annotations = create_annotations_from_guardrails_errors(
guardrails_result["errors"]
)
if ctx.guardrails.logging_guardrails:
logging_guardrails_check_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.LOG
)
if logging_guardrails_check_result and logging_guardrails_check_result.get(
"errors", []
):
annotations.extend(
create_annotations_from_guardrails_errors(
logging_guardrails_check_result["errors"]
)
)
deduplicated_annotations = deduplicate_annotations(ctx, annotations)
try:
# If the trace_id is None, create a new trace with the messages.
# Otherwise, append the message to the existing trace.
client = AsyncClient(
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
)
if ctx.trace_id is None:
ctx.trace.append(message)
# default metadata
metadata = {"source": "mcp"}
# include MCP session metadata
metadata.update(session_metadata(ctx))
response = await client.push_trace(
PushTracesRequest(
messages=[ctx.trace],
dataset=ctx.explorer_dataset,
metadata=[metadata],
annotations=[deduplicated_annotations],
)
)
ctx.trace_id = response.id[0]
ctx.last_trace_length = len(ctx.trace)
ctx.annotations.extend(deduplicated_annotations)
else:
ctx.trace.append(message)
response = await client.append_messages(
AppendMessagesRequest(
trace_id=ctx.trace_id,
messages=ctx.trace[ctx.last_trace_length :],
annotations=deduplicated_annotations,
)
)
ctx.last_trace_length = len(ctx.trace)
ctx.annotations.extend(deduplicated_annotations)
except Exception as e: # pylint: disable=broad-except
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
async def get_guardrails_check_result(
ctx: McpContext,
message: dict,
action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
"""
Check against guardrails of type action.
Works in both sync and async contexts by always using a dedicated thread.
"""
# Skip if no guardrails are configured for this action
if not (
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
):
return {}
# Prepare context and select appropriate guardrails
context = RequestContext.create(
request_json={},
dataset_name=ctx.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=ctx.guardrails,
guardrails_parameters={"metadata": session_metadata(ctx), "action": action},
)
guardrails_to_check = (
ctx.guardrails.blocking_guardrails
if action == GuardrailAction.BLOCK
else ctx.guardrails.logging_guardrails
)
return run_task_sync(
check_guardrails,
messages=ctx.trace + [message],
guardrails=guardrails_to_check,
context=context,
)
def json_rpc_error_response(
id_value: str | int, error_message: str, response_type: str = "error"
) -> dict:
"""
Create a JSON-RPC error response with either error object or content format.
Args:
id_value: The ID of the JSON-RPC request
error_message: The error message to include
response_type: Either "error" or "content" to determine response format
Returns:
A properly formatted JSON-RPC response dictionary
"""
base_response = {
"jsonrpc": "2.0",
"id": id_value,
}
if response_type == "error":
base_response["error"] = {
"code": -32600,
"message": error_message,
}
else:
base_response["result"] = {
"content": [
{
"type": "text",
"text": error_message,
}
]
}
return base_response
async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
"""
Hook function to intercept tool calls.
If the request is blocked, it returns a tuple with a message explaining the block
and a flag indicating the request was blocked.
Otherwise it returns the original request and a flag indicating it was not blocked.
"""
tool_call = {
"id": f"call_{request.get('id')}",
"type": "function",
"function": {
"name": request["params"]["name"],
"arguments": request["params"]["arguments"],
},
}
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
# Check for blocking guardrails
guardrailing_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
# If the request is blocked, return a message indicating the block reason.
if (
guardrailing_result
and guardrailing_result.get("errors", [])
and check_if_new_errors(ctx, guardrailing_result)
):
if ctx.push_explorer:
await append_and_push_trace(ctx, message, guardrailing_result)
else:
ctx.trace.append(message)
return json_rpc_error_response(
request.get("id"),
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrailing_result["errors"],
response_type=ctx.failure_response_format,
), True
# Add the message to the trace
ctx.trace.append(message)
return request, False
async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
"""
Hook function to intercept tool results.
Returns the potentially modified result.
"""
method = ctx.id_to_method_mapping.get(result.get("id"))
call_id = f"call_{result.get('id')}"
# Safely handle result object
result_obj = result.get("result", {})
if result_obj.get(MCP_SERVER_INFO):
ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "")
if not method:
return result
elif method == MCP_TOOL_CALL:
message = {
"role": "tool",
"content": result_obj.get("content"),
"error": result_obj.get("error"),
"tool_call_id": call_id,
}
# Check for blocking guardrails
guardrailing_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
if guardrailing_result and guardrailing_result.get("errors", []):
result = json_rpc_error_response(
result.get("id"),
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% format_errors_in_response(guardrailing_result["errors"]),
response_type=ctx.failure_response_format, # Using content type as that's what the original code used
)
if ctx.push_explorer:
await append_and_push_trace(ctx, message, guardrailing_result)
else:
ctx.trace.append(message)
return result
elif method == MCP_LIST_TOOLS:
ctx.tools = result_obj.get("tools")
message = {
"role": "tool",
"content": json.dumps(result.get("result").get("tools")),
"tool_call_id": call_id,
}
# next validate it with guardrails
guardrailing_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
if guardrailing_result and guardrailing_result.get("errors", []):
result["result"]["tools"] = [
{
"name": "blocked_" + tool["name"],
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
% format_errors_in_response(guardrailing_result["errors"]),
# no parameters
"inputSchema": {
"properties": {},
"required": [],
"title": "invariant_mcp_server_blockedArguments",
"type": "object",
},
"annotations": {
"title": "This tool was blocked by security guardrails.",
},
}
for tool in result["result"]["tools"]
]
# add it to the session trace (and run logging guardrails)
if ctx.push_explorer:
await append_and_push_trace(ctx, message, guardrailing_result)
else:
ctx.trace.append(message)
return result
else:
return result
async def stream_and_forward_stdout(
mcp_process: subprocess.Popen, ctx: McpContext
session_id: str, mcp_process: subprocess.Popen
) -> None:
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
loop = asyncio.get_event_loop()
while True:
if mcp_process.poll() is not None:
mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}")
break
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
if not line:
break
try:
# Process complete JSON lines
line_str = line.decode(UTF_8).strip()
if not line_str:
decoded_line = line.decode(UTF_8).strip()
if not decoded_line:
continue
session = session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] server -> client: {decoded_line}")
response_body = json.loads(decoded_line)
update_mcp_server_in_session_metadata(session, response_body)
if ctx.verbose:
mcp_log(f"[INFO] server -> client: {line_str}")
parsed_json = json.loads(line_str)
processed_json = await hook_tool_result(ctx, parsed_json)
intercept_response_result, _ = await intercept_response(
session_id, session_store, response_body
)
# Write and flush immediately
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result))
sys.stdout.buffer.flush()
except Exception as e: # pylint: disable=broad-except
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
if line:
@@ -407,74 +91,52 @@ async def stream_and_forward_stderr(
MCP_LOG_FILE.buffer.flush()
async def process_line(
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
async def _intercept_request(
session_id: str, mcp_process: subprocess.Popen, line: bytes
) -> None:
"""Process a line of input from stdin, decode it, and forward to mcp_process."""
if ctx.verbose:
"""
Process a line of input from stdin, decode it and check for guardrails.
If the request is blocked, it returns a message indicating the block reason
otherwise it forwards the request to mcp_process stdin.
"""
session = session_store.get_session(session_id)
if session.attributes.verbose:
mcp_log(f"[INFO] client -> server: {line}")
# Try to decode and parse as JSON to check for tool calls
try:
text = line.decode(UTF_8)
parsed_json = json.loads(text)
request_body = json.loads(text)
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
return
update_session_from_request(session, request_body)
# Refresh guardrails
await session.load_guardrails()
if parsed_json.get(MCP_METHOD) is not None:
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD)
if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO):
ctx.mcp_client_name = (
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
hook_tool_call_result = {}
is_blocked = False
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id, session_store, request_body
)
# Check if this is a tool call request
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
# Refresh guardrails
run_task_sync(ctx.load_guardrails)
# Intercept and potentially block modify the request
hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json)
if not is_blocked:
# If blocked, hook_tool_call_result contains the original request.
# Forward the request to the MCP process.
# It will handle the request and return a response.
mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result))
mcp_process.stdin.flush()
else:
# If blocked, hook_tool_call_result contains the block message.
# Forward the block message result back to the caller.
# The original request is not passed to the MCP process.
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
sys.stdout.buffer.flush()
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
hook_tool_call_result, is_blocked = await hook_tool_call(
session_id=session_id,
session_store=session_store,
request_body={
"id": request_body.get("id"),
"method": MCP_LIST_TOOLS,
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
},
)
if is_blocked:
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
sys.stdout.buffer.flush()
return
else:
# pass through the request to the MCP process
# for list_tools, extend the trace by a tool call
if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
# Refresh guardrails
run_task_sync(ctx.load_guardrails)
ctx.trace.append(
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": f"call_{parsed_json.get('id')}",
"type": "function",
"function": {
"name": "tools/list",
"arguments": {},
},
}
],
}
)
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
mcp_process.stdin.flush()
mcp_process.stdin.write(write_as_utf8_bytes(request_body))
mcp_process.stdin.flush()
async def wait_for_stdin_input(
@@ -523,7 +185,7 @@ async def wait_for_stdin_input(
async def run_stdio_input_loop(
ctx: McpContext,
session_id: str,
mcp_process: subprocess.Popen,
stdout_task: asyncio.Task,
stderr_task: asyncio.Task,
@@ -557,7 +219,7 @@ async def run_stdio_input_loop(
if not line:
continue
await process_line(ctx, mcp_process, line)
await _intercept_request(session_id, mcp_process, line)
except (BrokenPipeError, KeyboardInterrupt):
# Broken pipe = client disappeared, just start shutdown
mcp_log("Client disconnected or keyboard interrupt")
@@ -570,7 +232,7 @@ async def run_stdio_input_loop(
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if line:
await process_line(ctx, mcp_process, line)
await _intercept_request(session_id, mcp_process, line)
# Terminate process if needed
if mcp_process.poll() is None:
@@ -631,7 +293,11 @@ async def execute(args: list[str] = None):
mcp_log("[INFO] Running with Python version:", sys.version)
mcp_gateway_args, mcp_server_command_args = split_args(args)
ctx = McpContext(mcp_gateway_args)
session_id = generate_session_id()
await session_store.initialize_session(
session_id,
McpAttributes.from_cli_args(mcp_gateway_args),
)
mcp_process = subprocess.Popen(
mcp_server_command_args,
@@ -642,8 +308,10 @@ async def execute(args: list[str] = None):
)
# Start async tasks for stdout and stderr
stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx))
stdout_task = asyncio.create_task(
stream_and_forward_stdout(session_id, mcp_process)
)
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
# Handle forwarding stdin and intercept tool calls
await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task)
await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task)

View File

@@ -1,112 +0,0 @@
"""Context manager for MCP (Model Context Protocol) gateway."""
import argparse
import os
import random
import uuid
from typing import Dict
from gateway.integrations.explorer import (
fetch_guardrails_from_explorer,
)
from gateway.common.guardrails import GuardrailRuleSet
class McpContext:
"""Singleton class to manage MCP context and state."""
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(McpContext, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, cli_args: list):
if not hasattr(self, "_initialized"):
self._initialized = False
if self._initialized:
return
config, extra_args = self._parse_cli_args(cli_args)
# The project name is used to identify the dataset in Invariant Explorer.
self.explorer_dataset = config.project_name
# whether to push traces to Invariant Explorer
self.push_explorer = config.push_explorer
# the format to use to communicate guardrail failures to the client
self.failure_response_format = config.failure_response_format
# verbose logging of in/out
self.verbose = config.verbose
# trace of this MCP session
self.trace = []
# tools available to the MCP server
self.tools = []
# configured guardrails
self.guardrails = GuardrailRuleSet(
blocking_guardrails=[], logging_guardrails=[]
)
# parsed from CLI (all --metadata-* args)
self.extra_metadata: Dict[str, str] = {}
for arg in extra_args:
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
key, value = arg.split("=")
assert key.startswith(
"--metadata-"
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
key = key[len("--metadata-") :]
self.extra_metadata[key] = value
# captured from MCP calls/responses
self.mcp_client_name = ""
self.mcp_server_name = ""
# We send the same trace messages for guardrails analysis multiple times.
# We need to deduplicate them before sending to the explorer.
self.annotations = []
self.trace_id = None
self.local_session_id = str(uuid.uuid4())
self.last_trace_length = 0
self.id_to_method_mapping = {}
self._initialized = True
def _parse_cli_args(self, cli_args: list) -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument(
"--project-name",
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
type=str,
default=f"mcp-capture-{random.randint(1, 100)}",
)
parser.add_argument(
"--push-explorer",
help="Enable pushing traces to Invariant Explorer",
action="store_true",
)
parser.add_argument(
"--verbose",
help="Enable verbose logging",
action="store_true",
)
parser.add_argument(
"--failure-response-format",
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
type=str,
default="error",
)
return parser.parse_known_args(cli_args)
async def load_guardrails(self):
"""Run async setup logic (e.g. fetching guardrails)."""
self.guardrails = await fetch_guardrails_from_explorer(
self.explorer_dataset,
"Bearer " + os.getenv("INVARIANT_API_KEY"),
self.extra_metadata.get("client", self.mcp_client_name),
self.extra_metadata.get("server", self.mcp_server_name),
)

View File

@@ -15,10 +15,6 @@ from gateway.common.constants import (
MCP_METHOD,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_INFO,
MCP_CLIENT_INFO,
UTF_8,
)
from gateway.common.mcp_sessions_manager import (
@@ -28,7 +24,9 @@ from gateway.common.mcp_sessions_manager import (
from gateway.common.mcp_utils import (
get_mcp_server_base_url,
hook_tool_call,
hook_tool_call_response,
intercept_response,
update_mcp_server_in_session_metadata,
update_session_from_request,
)
MCP_SERVER_POST_HEADERS = {
@@ -72,16 +70,7 @@ async def mcp_post_sse_gateway(
request_body_bytes = await request.body()
request_body = json.loads(request_body_bytes)
session = session_store.get_session(session_id)
if request_body.get(MCP_METHOD) and request_body.get("id"):
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
MCP_METHOD
)
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.attributes.metadata["mcp_client"] = (
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
update_session_from_request(session, request_body)
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
# Intercept and potentially block the request
@@ -137,8 +126,6 @@ async def mcp_post_sse_gateway(
print(f"[MCP POST] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
except Exception as e:
import traceback
traceback.print_exc()
print(f"[MCP POST] Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Unexpected error") from e
@@ -340,59 +327,18 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8)
session = session_store.get_session(session_id)
try:
response_json = json.loads(sse.data)
response_body = json.loads(sse.data)
update_mcp_server_in_session_metadata(session, response_body)
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.attributes.metadata["mcp_server"] = (
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
intercept_response_result, is_blocked = await intercept_response(
session_id=session_id,
session_store=session_store,
response_body=response_body,
)
if is_blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(intercept_response_result)}\n\n".encode(
UTF_8
)
method = session.id_to_method_mapping.get(response_json.get("id"))
if method == MCP_TOOL_CALL:
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json=response_json,
)
# Update the event bytes with hook_tool_call_response.
# hook_tool_call_response is same as response_json if no guardrail is violated.
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
if blocked:
event_bytes = (
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
)
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json={
"id": response_json.get("id"),
"result": {
"content": json.dumps(
response_json.get(MCP_RESULT).get("tools")
),
"tools": response_json.get(MCP_RESULT).get("tools"),
},
},
is_tools_list=True,
)
# Update the event bytes with hook_tool_call_response.
# hook_tool_call_response is same as response_json if no guardrail is violated.
# If guardrail is violated, it contains the error message.
# pylint: disable=line-too-long
if blocked:
event_bytes = (
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
)
except json.JSONDecodeError as e:
print(
f"[MCP SSE] Error parsing message JSON: {e}",

View File

@@ -1,9 +1,6 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import json
import uuid
from typing import Tuple
import httpx
@@ -13,12 +10,8 @@ from fastapi.responses import StreamingResponse
from gateway.common.constants import (
CLIENT_TIMEOUT,
INVARIANT_SESSION_ID_PREFIX,
MCP_CLIENT_INFO,
MCP_LIST_TOOLS,
MCP_METHOD,
MCP_PARAMS,
MCP_RESULT,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
UTF_8,
)
@@ -27,9 +20,13 @@ from gateway.common.mcp_sessions_manager import (
McpAttributes,
)
from gateway.common.mcp_utils import (
generate_session_id,
get_mcp_server_base_url,
hook_tool_call,
hook_tool_call_response,
intercept_response,
update_mcp_client_info_in_session,
update_mcp_server_in_session_metadata,
update_tool_call_id_in_session,
)
gateway = APIRouter()
@@ -69,7 +66,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
# If a session ID is provided in the request headers, it was already initialized
# in McpSessionsManager. This might be a session ID returned by the MCP server
# or a session ID generated in the gateway.
_update_tool_call_id_in_session(session_id, request_body)
update_tool_call_id_in_session(
session_store.get_session(session_id), request_body
)
elif is_initialization_request:
# If this is an initialization request, we generate a session ID,
# We don't call initialize_session here because we don't know
@@ -77,7 +76,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
# If later in the response from MCP server, we don't receive a session ID then this
# will be initialized and returned back to the client else this will be
# overwritten by the session ID returned by the MCP server.
session_id = _generate_session_id()
session_id = generate_session_id()
# Intercept the request and check for guardrails.
if not is_initialization_request:
@@ -115,9 +114,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
# Update client info if this is an initialization request
if is_initialization_request:
_update_mcp_client_info_in_session(
session_id=session_id,
request_body=request_body,
update_mcp_client_info_in_session(
session_store.get_session(session_id),
request_body,
)
# If the response is JSON type, handle it as a JSON response.
@@ -281,52 +280,17 @@ def _get_mcp_server_endpoint(request: Request) -> str:
return get_mcp_server_base_url(request) + "/mcp/"
def _generate_session_id() -> str:
"""
Generate a new session ID.
If the MCP server is session less then we don't have a session ID from the MCP server.
"""
return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex
def _update_tool_call_id_in_session(session_id: str, request_body: dict) -> None:
"""
Updates the tool call ID in the session.
"""
session = session_store.get_session(session_id)
if request_body.get(MCP_METHOD) and request_body.get("id"):
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
MCP_METHOD
)
def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> None:
"""
Update the MCP client info in the session metadata.
"""
session = session_store.get_session(session_id)
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.attributes.metadata["mcp_client"] = (
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
def _update_mcp_response_info_in_session(
session_id: str, response_json: dict, is_json_response: bool
session_id: str, response_body: dict, is_json_response: bool
) -> None:
"""
Update the MCP response info in the session metadata.
"""
session = session_store.get_session(session_id)
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.attributes.metadata["mcp_server"] = (
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
)
session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse"
update_mcp_server_in_session_metadata(session, response_body)
session.attributes.metadata["server_response_type"] = (
"json" if is_json_response else "sse"
)
def _is_initialization_request(request_body: dict) -> bool:
@@ -353,18 +317,20 @@ async def _handle_mcp_json_response(
# return the error message else return the response as is
response_content = response.content
# The server response is empty string when client sends "notifications/initialized"
response_json = (
response_body = (
json.loads(response_content.decode(UTF_8)) if response_content else {}
)
if response_json:
if response_body:
_update_mcp_response_info_in_session(
session_id=session_id, response_json=response_json, is_json_response=True
session_id=session_id, response_body=response_body, is_json_response=True
)
response_code = response.status_code
if not is_initialization_request:
intercept_response_result, blocked = await _intercept_response(
session_id=session_id, response_json=response_json
intercept_response_result, blocked = await intercept_response(
session_id=session_id,
session_store=session_store,
response_body=response_body,
)
if blocked:
response_content = json.dumps(intercept_response_result).encode(UTF_8)
@@ -405,14 +371,15 @@ async def _handle_mcp_streaming_response(
if not stripped_line:
break # End of stream
if buffer:
response_json = json.loads(stripped_line.split("data: ")[1].strip())
response_body = json.loads(stripped_line.split("data: ")[1].strip())
if not is_initialization_request:
(
intercept_response_result,
blocked,
) = await _intercept_response(
) = await intercept_response(
session_id=session_id,
response_json=response_json,
session_store=session_store,
response_body=response_body,
)
if blocked:
yield (
@@ -423,7 +390,7 @@ async def _handle_mcp_streaming_response(
else:
_update_mcp_response_info_in_session(
session_id=session_id,
response_json=response_json,
response_body=response_body,
is_json_response=False,
)
yield f"{buffer}\n{stripped_line}\n\n"
@@ -482,46 +449,3 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
media_type="application/json",
)
return None
async def _intercept_response(
session_id: str, response_json: dict
) -> Tuple[dict, bool]:
"""
Intercept the response and check for guardrails.
This function is used to intercept responses and check for guardrails.
If the response is blocked, it returns a message indicating the block
reason with a boolean flag set to True. If the response is not blocked,
it returns the original response with a boolean flag set to False.
"""
session = session_store.get_session(session_id)
method = session.id_to_method_mapping.get(response_json.get("id"))
# Intercept and potentially block tool call response
if method == MCP_TOOL_CALL:
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json=response_json,
)
return result, blocked
# Intercept and potentially block list tool call response
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
result, blocked = await hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_json={
"id": response_json.get("id"),
"result": {
"content": json.dumps(response_json.get(MCP_RESULT).get("tools")),
"tools": response_json.get(MCP_RESULT).get("tools"),
},
},
is_tools_list=True,
)
return result, blocked
return response_json, False

View File

@@ -1,5 +1,6 @@
"""A MCP client implementation that interacts with MCP server to make tool calls."""
import asyncio
import os
from datetime import timedelta
@@ -48,7 +49,6 @@ class MCPClient:
for key, value in metadata_keys.items():
args.append("--metadata-" + key + "=" + value)
if push_to_explorer:
args.append("--push-explorer")
args.extend(
@@ -133,7 +133,7 @@ async def run(
project_name,
server_script_path,
push_to_explorer,
metadata_keys=metadata_keys
metadata_keys=metadata_keys,
)
listed_tools = await client.session.list_tools()
if tool_name == "tools/list":
@@ -142,4 +142,6 @@ async def run(
else:
return await client.call_tool(tool_name, tool_args)
finally:
if push_to_explorer:
await asyncio.sleep(2)
await client.cleanup()