mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Refactor stdio implementation to use McpSession class.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Common constants used in the gateway."""
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
IGNORED_HEADERS = [
|
||||
"accept-encoding",
|
||||
"host",
|
||||
|
||||
@@ -16,7 +16,7 @@ from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
from gateway.common.constants import INVARIANT_SESSION_ID_PREFIX
|
||||
from gateway.common.constants import DEFAULT_API_URL, INVARIANT_SESSION_ID_PREFIX
|
||||
from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import (
|
||||
@@ -25,8 +25,6 @@ from gateway.integrations.explorer import (
|
||||
)
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def user_and_host() -> str:
|
||||
"""Get the current user and hostname."""
|
||||
@@ -61,30 +59,15 @@ class McpSession(BaseModel):
|
||||
# and other session-related operations
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
def get_invariant_api_key(self) -> str:
|
||||
def _get_invariant_api_key(self) -> str:
|
||||
"""Get the Invariant API key for the session."""
|
||||
if self.attributes.invariant_api_key:
|
||||
return self.attributes.invariant_api_key
|
||||
return os.getenv("INVARIANT_API_KEY")
|
||||
|
||||
def get_invariant_authorization(self) -> str:
|
||||
def _get_invariant_authorization(self) -> str:
|
||||
"""Get the Invariant authorization header for the session."""
|
||||
return "Bearer " + self.get_invariant_api_key()
|
||||
|
||||
async def load_guardrails(self) -> None:
|
||||
"""
|
||||
Load guardrails for the session.
|
||||
|
||||
This method fetches guardrails from the Invariant Explorer and assigns them to the session.
|
||||
"""
|
||||
print("Inside load_guardrails attributes: ", self.attributes, flush=True)
|
||||
self.guardrails = await fetch_guardrails_from_explorer(
|
||||
self.attributes.explorer_dataset,
|
||||
self.get_invariant_authorization(),
|
||||
# pylint: disable=no-member
|
||||
self.attributes.metadata.get("mcp_client"),
|
||||
self.attributes.metadata.get("mcp_server"),
|
||||
)
|
||||
return "Bearer " + self._get_invariant_api_key()
|
||||
|
||||
def _deduplicate_annotations(self, new_annotations: list) -> list:
|
||||
"""Deduplicate new_annotations using the annotations in the session."""
|
||||
@@ -94,6 +77,20 @@ class McpSession(BaseModel):
|
||||
deduped_annotations.append(annotation)
|
||||
return deduped_annotations
|
||||
|
||||
async def load_guardrails(self) -> None:
    """
    Load guardrails for the session.

    Fetches the guardrails for this session's explorer dataset from the
    Invariant Explorer and stores the result on ``self.guardrails``. The
    session's authorization header and the ``mcp_client`` / ``mcp_server``
    entries of the session metadata (``None`` when absent) are passed along
    with the request.
    """
    attributes = self.attributes
    # pylint: disable=no-member
    metadata = attributes.metadata
    self.guardrails = await fetch_guardrails_from_explorer(
        attributes.explorer_dataset,
        self._get_invariant_authorization(),
        metadata.get("mcp_client"),
        metadata.get("mcp_server"),
    )
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def session_lock(self):
|
||||
"""
|
||||
@@ -134,15 +131,10 @@ class McpSession(BaseModel):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
print(
|
||||
"Inside get_guardrails_check_result attributes: ",
|
||||
self.attributes,
|
||||
flush=True,
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=self.attributes.explorer_dataset,
|
||||
invariant_authorization=self.get_invariant_authorization(),
|
||||
invariant_authorization=self._get_invariant_authorization(),
|
||||
guardrails=self.guardrails,
|
||||
guardrails_parameters={
|
||||
"metadata": self.session_metadata(),
|
||||
@@ -210,11 +202,10 @@ class McpSession(BaseModel):
|
||||
|
||||
This is an internal method that should only be called within a lock.
|
||||
"""
|
||||
print("Inside _push_trace_update attributes: ", self.attributes, flush=True)
|
||||
try:
|
||||
client = AsyncClient(
|
||||
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
|
||||
api_key=self.get_invariant_api_key(),
|
||||
api_key=self._get_invariant_api_key(),
|
||||
)
|
||||
|
||||
# If no trace exists, create a new one
|
||||
@@ -247,7 +238,7 @@ class McpSession(BaseModel):
|
||||
self.annotations.extend(deduplicated_annotations)
|
||||
self.last_trace_length = len(self.messages)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}")
|
||||
print(f"[MCP] Error pushing trace for session {self.session_id}: {e}")
|
||||
|
||||
async def add_pending_error_message(self, error_message: dict) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""MCP utility functions."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
@@ -9,18 +11,83 @@ from fastapi import Request, HTTPException
|
||||
from gateway.common.constants import (
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
INVARIANT_SESSION_ID_PREFIX,
|
||||
MCP_CLIENT_INFO,
|
||||
MCP_SERVER_BASE_URL_HEADER,
|
||||
MCP_LIST_TOOLS,
|
||||
MCP_METHOD,
|
||||
MCP_PARAMS,
|
||||
MCP_RESULT,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_TOOL_CALL,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSession,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
|
||||
|
||||
def _check_if_new_errors(
    session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
    """
    Tell whether the guardrails result carries an error not yet seen.

    The errors in ``guardrails_result`` are converted to annotations and
    compared against the annotations already stored on the session.

    Args:
        session_id (str): The session whose stored annotations are consulted.
        session_store (McpSessionsManager): Store used to look the session up.
        guardrails_result (dict): Guardrails response; its ``"errors"`` list
            (empty when missing) is inspected.

    Returns:
        bool: True as soon as one annotation is not in ``session.annotations``,
        False when all are known or there are no errors.
    """
    session = session_store.get_session(session_id)
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(candidate not in session.annotations for candidate in candidates)
|
||||
|
||||
|
||||
def generate_session_id() -> str:
    """
    Generate a new session ID.

    Used when the MCP server is sessionless, i.e. it does not hand out a
    session ID of its own. The ID is ``INVARIANT_SESSION_ID_PREFIX`` followed
    by the hex form of a random UUID4.
    """
    random_suffix = uuid.uuid4().hex
    return INVARIANT_SESSION_ID_PREFIX + random_suffix
|
||||
|
||||
|
||||
def update_mcp_server_in_session_metadata(
    session: McpSession, response_body: dict
) -> None:
    """
    Update the MCP server information in the session metadata.

    When ``response_body`` carries a result object with a non-empty server
    info object, the server's ``name`` (``""`` if missing) is stored under
    ``session.attributes.metadata["mcp_server"]``. Any other response leaves
    the session untouched.

    Args:
        session (McpSession): Session whose metadata is updated in place.
        response_body (dict): Parsed JSON-RPC response from the MCP server.
    """
    result = response_body.get(MCP_RESULT)
    # A JSON-RPC result may be any JSON value; only an object can hold
    # server info, so anything else is ignored instead of raising.
    if not isinstance(result, dict):
        return
    server_info = result.get(MCP_SERVER_INFO)
    if isinstance(server_info, dict) and server_info:
        session.attributes.metadata["mcp_server"] = server_info.get("name", "")
|
||||
|
||||
|
||||
def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None:
    """
    Record which method a request ID belongs to.

    Stores ``request id -> method`` in ``session.id_to_method_mapping`` for
    requests that name a method and carry an ID. Messages without an ID, or
    without a method, are not recorded.

    Args:
        session (McpSession): Session whose mapping is updated in place.
        request_body (dict): Parsed JSON-RPC request from the client.
    """
    method = request_body.get(MCP_METHOD)
    request_id = request_body.get("id")
    # Compare against None instead of relying on truthiness: 0 is a valid
    # JSON-RPC id and must still be mapped.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method
|
||||
|
||||
|
||||
def update_mcp_client_info_in_session(session: McpSession, request_body: dict) -> None:
    """
    Update the MCP client info in the session metadata.

    When ``request_body`` carries a params object with a non-empty client
    info object, the client's ``name`` (``""`` if missing) is stored under
    ``session.attributes.metadata["mcp_client"]``. Any other request leaves
    the session untouched.

    Args:
        session (McpSession): Session whose metadata is updated in place.
        request_body (dict): Parsed JSON-RPC request from the client.
    """
    params = request_body.get(MCP_PARAMS)
    # JSON-RPC params may also be an array; only an object can hold client
    # info, so anything else is ignored instead of raising.
    if not isinstance(params, dict):
        return
    client_info = params.get(MCP_CLIENT_INFO)
    if isinstance(client_info, dict) and client_info:
        session.attributes.metadata["mcp_client"] = client_info.get("name", "")
|
||||
|
||||
|
||||
def update_session_from_request(session: McpSession, request_body: dict) -> None:
    """
    Update the MCP client information and request id in the session.

    Applies the client-info update first and the request-id bookkeeping
    second, both against the same ``request_body``.
    """
    for apply_update in (
        update_mcp_client_info_in_session,
        update_tool_call_id_in_session,
    ):
        apply_update(session, request_body)
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
"""
|
||||
Convert localhost or 127.0.0.1 in an address to host.docker.internal
|
||||
@@ -73,7 +140,13 @@ async def hook_tool_call(
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
session_store (McpSessionsManager): The session store to manage sessions.
|
||||
request_body (dict): The request JSON to be processed.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean
|
||||
indicating whether the request was blocked. If the request is blocked, the
|
||||
dict will contain an error message else it will contain the original request.
|
||||
"""
|
||||
tool_call = {
|
||||
"id": f"call_{request_body.get('id')}",
|
||||
@@ -89,14 +162,13 @@ async def hook_tool_call(
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True)
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
and _check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
):
|
||||
# Add the trace to the explorer
|
||||
asyncio.create_task(
|
||||
@@ -120,24 +192,10 @@ async def hook_tool_call(
|
||||
return request_body, False
|
||||
|
||||
|
||||
def check_if_new_errors(
    session_id: str, session_store: McpSessionsManager, guardrails_result: dict
) -> bool:
    """
    Checks if there are new errors in the guardrails result.

    The ``"errors"`` list of ``guardrails_result`` (empty when missing) is
    turned into annotations; the result is True as soon as one of them is not
    already in the session's ``annotations``, otherwise False.
    """
    session = session_store.get_session(session_id)
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(candidate not in session.annotations for candidate in candidates)
|
||||
|
||||
|
||||
async def hook_tool_call_response(
|
||||
session_id: str,
|
||||
session_store: McpSessionsManager,
|
||||
response_json: dict,
|
||||
response_body: dict,
|
||||
is_tools_list=False,
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -145,19 +203,21 @@ async def hook_tool_call_response(
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
response_json (dict): The response JSON to be processed.
|
||||
session_store (McpSessionsManager): The session store to manage sessions.
|
||||
response_body (dict): The response JSON to be processed.
|
||||
is_tools_list (bool): Flag to indicate if the response is from a tools/list call.
|
||||
Returns:
|
||||
dict: The response JSON is returned if no guardrail is violated
|
||||
else an error dict is returned.
|
||||
"""
|
||||
blocked = False
|
||||
is_blocked = False
|
||||
result = response_body
|
||||
message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": f"call_{response_json.get('id')}",
|
||||
"content": response_json.get(MCP_RESULT).get("content"),
|
||||
"error": response_json.get(MCP_RESULT).get("error"),
|
||||
"tool_call_id": f"call_{result.get('id')}",
|
||||
"content": result.get(MCP_RESULT).get("content"),
|
||||
"error": result.get(MCP_RESULT).get("error"),
|
||||
}
|
||||
result = response_json
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
@@ -166,14 +226,14 @@ async def hook_tool_call_response(
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
and _check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
):
|
||||
blocked = True
|
||||
is_blocked = True
|
||||
# If the request is blocked, return a message indicating the block reason
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"id": response_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
@@ -184,7 +244,7 @@ async def hook_tool_call_response(
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"id": response_body.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
@@ -201,13 +261,64 @@ async def hook_tool_call_response(
|
||||
"title": "This tool was blocked by security guardrails.",
|
||||
},
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
for tool in response_body["result"]["tools"]
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result, blocked
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
return result, is_blocked
|
||||
|
||||
|
||||
async def intercept_response(
    session_id: str, session_store: McpSessionsManager, response_body: dict
) -> Tuple[dict, bool]:
    """
    Intercept the response and check for guardrails.

    The method that produced the response is looked up through the session's
    request-id mapping. Tool call and tools/list responses are passed through
    ``hook_tool_call_response``; every other response is returned unchanged
    with the blocked flag set to False.

    Args:
        session_id (str): The session ID associated with the request.
        session_store (McpSessionsManager): The session store to manage sessions.
        response_body (dict): The response JSON to be processed.

    Returns:
        Tuple[dict, bool]: A tuple containing the processed response JSON
        and a boolean indicating whether the response was blocked.
    """
    session = session_store.get_session(session_id)
    method = session.id_to_method_mapping.get(response_body.get("id"))

    intercept_response_result = response_body
    is_blocked = False

    if method == MCP_TOOL_CALL:
        # Intercept and potentially block the tool call response
        intercept_response_result, is_blocked = await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body=response_body,
        )
    elif method == MCP_LIST_TOOLS:
        # NOTE(review): a tools/list response without a result member (e.g. a
        # JSON-RPC error) raises AttributeError here - confirm callers expect that.
        tools = response_body.get(MCP_RESULT).get("tools")
        # Store tools in the session metadata (reuses the session fetched above)
        session.attributes.metadata["tools"] = tools
        # Intercept and potentially block the tools/list response; the tool
        # list is serialized into "content" so it can be checked like a tool
        # result, and kept under "tools" for building the blocked response.
        intercept_response_result, is_blocked = await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_body={
                "jsonrpc": "2.0",
                "id": response_body.get("id"),
                "result": {
                    "content": json.dumps(tools),
                    "tools": tools,
                },
            },
            is_tools_list=True,
        )
    return intercept_response_result, is_blocked
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
|
||||
from gateway.common.constants import DEFAULT_API_URL
|
||||
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
@@ -13,8 +14,6 @@ from invariant_sdk.types.annotations import AnnotationCreate
|
||||
|
||||
import httpx
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def create_annotations_from_guardrails_errors(
|
||||
guardrails_errors: List[dict],
|
||||
|
||||
@@ -9,14 +9,13 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import httpx
|
||||
|
||||
from gateway.common.constants import DEFAULT_API_URL
|
||||
from gateway.common.guardrails import Guardrail
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.common.authorization import (
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
|
||||
)
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
# Timestamps of last API calls per guardrails string
|
||||
_guardrails_cache = {}
|
||||
|
||||
@@ -1,63 +1,36 @@
|
||||
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import select
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
|
||||
from gateway.common.constants import (
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_CLIENT_INFO,
|
||||
MCP_PARAMS,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_sync
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
generate_session_id,
|
||||
hook_tool_call,
|
||||
intercept_response,
|
||||
update_mcp_server_in_session_metadata,
|
||||
update_session_from_request,
|
||||
)
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
STATUS_WAIT = "wait"
|
||||
|
||||
|
||||
def user_and_host() -> str:
    """
    Get the current user and hostname.

    Returns:
        str: ``"<username>@<hostname>"``. The username falls back to
        ``"unknown"`` when the login name cannot be determined.
    """
    try:
        username = getpass.getuser()
    except (KeyError, OSError, ImportError):
        # getpass.getuser() raises when neither the environment nor the
        # password database yields a login name (e.g. a container UID with
        # no passwd entry); don't let that abort metadata generation.
        username = "unknown"
    hostname = socket.gethostname()

    return f"{username}@{hostname}"
|
||||
|
||||
|
||||
def session_metadata(ctx: McpContext) -> dict:
    """
    Generate metadata for the current session.

    Returns:
        dict: ``session_id``, ``system_user``, ``mcp_client``, ``mcp_server``
        and ``tools`` taken from the context, overlaid with
        ``ctx.extra_metadata`` (if any), whose keys win on conflict.
    """
    metadata = {
        "session_id": ctx.local_session_id,
        "system_user": user_and_host(),
        "mcp_client": ctx.mcp_client_name,
        "mcp_server": ctx.mcp_server_name,
        "tools": ctx.tools,
    }
    extra = ctx.extra_metadata or {}
    metadata.update(extra)
    return metadata
|
||||
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
def write_as_utf8_bytes(data: dict) -> bytes:
|
||||
@@ -65,326 +38,37 @@ def write_as_utf8_bytes(data: dict) -> bytes:
|
||||
return json.dumps(data).encode(UTF_8) + b"\n"
|
||||
|
||||
|
||||
def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list:
    """
    Deduplicate new_annotations using the annotations in the context.

    Returns:
        list: The entries of ``new_annotations``, in their original order,
        that are not already present in ``ctx.annotations``. Repeats within
        ``new_annotations`` itself are not collapsed.
    """
    return [
        annotation
        for annotation in new_annotations
        if annotation not in ctx.annotations
    ]
|
||||
|
||||
|
||||
def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
    """
    Checks if there are new errors in the guardrails result.

    The ``"errors"`` list of ``guardrails_result`` (empty when missing) is
    turned into annotations; the result is True as soon as one of them is not
    already in ``ctx.annotations``, otherwise False.
    """
    candidates = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    return any(candidate not in ctx.annotations for candidate in candidates)
|
||||
|
||||
|
||||
async def append_and_push_trace(
    ctx: McpContext, message: dict, guardrails_result: dict
) -> None:
    """
    Append a message to the trace if it exists or create a new one
    and push it to the Invariant Explorer.

    Annotations are built from the errors in ``guardrails_result`` and, when
    logging guardrails are configured, from an additional LOG-action check of
    ``message``. Annotations already recorded on the context are dropped
    before pushing.

    Args:
        ctx (McpContext): Session context; its trace, trace id, trace length
            and annotations are updated in place.
        message (dict): The message to add to the trace.
        guardrails_result (dict): Result of the blocking guardrails check;
            may be empty.

    Any exception raised while pushing is logged and swallowed, so a failed
    push never propagates to the caller.
    """
    annotations = []
    if guardrails_result and guardrails_result.get("errors", []):
        annotations = create_annotations_from_guardrails_errors(
            guardrails_result["errors"]
        )

    if ctx.guardrails.logging_guardrails:
        logging_guardrails_check_result = await get_guardrails_check_result(
            ctx, message, action=GuardrailAction.LOG
        )
        if logging_guardrails_check_result and logging_guardrails_check_result.get(
            "errors", []
        ):
            annotations.extend(
                create_annotations_from_guardrails_errors(
                    logging_guardrails_check_result["errors"]
                )
            )
    deduplicated_annotations = deduplicate_annotations(ctx, annotations)

    try:
        client = AsyncClient(
            api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
        )
        # The message joins the local trace only once the client exists, in
        # both the create and the append case.
        ctx.trace.append(message)

        if ctx.trace_id is None:
            # No trace yet: push the whole local trace as a new one.
            # Default metadata, then the MCP session metadata on top.
            metadata = {"source": "mcp"}
            metadata.update(session_metadata(ctx))

            response = await client.push_trace(
                PushTracesRequest(
                    messages=[ctx.trace],
                    dataset=ctx.explorer_dataset,
                    metadata=[metadata],
                    annotations=[deduplicated_annotations],
                )
            )
            ctx.trace_id = response.id[0]
        else:
            # Existing trace: send only the messages added since the last push.
            await client.append_messages(
                AppendMessagesRequest(
                    trace_id=ctx.trace_id,
                    messages=ctx.trace[ctx.last_trace_length :],
                    annotations=deduplicated_annotations,
                )
            )

        # Bookkeeping happens only after a successful push, so a failed push
        # leaves its messages to be sent with the next append.
        ctx.last_trace_length = len(ctx.trace)
        ctx.annotations.extend(deduplicated_annotations)
    except Exception as e:  # pylint: disable=broad-except
        mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
    ctx: McpContext,
    message: dict,
    action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
    """
    Check against guardrails of type action.

    The guardrails matching ``action`` are evaluated via ``run_task_sync``
    against the context's trace with ``message`` appended.

    Args:
        ctx (McpContext): Session context providing guardrails, trace and dataset.
        message (dict): The candidate message to check.
        action (GuardrailAction): Which guardrails to run (BLOCK or LOG).

    Returns:
        dict: The check result, or ``{}`` when no guardrails are configured
        for ``action``.
    """
    # Pick the guardrails for this action; anything else is skipped
    if action == GuardrailAction.BLOCK:
        guardrails_to_check = ctx.guardrails.blocking_guardrails
    elif action == GuardrailAction.LOG:
        guardrails_to_check = ctx.guardrails.logging_guardrails
    else:
        return {}
    if not guardrails_to_check:
        return {}

    # NOTE(review): the concatenation below raises TypeError when
    # INVARIANT_API_KEY is unset - confirm it is always set before this runs.
    authorization = "Bearer " + os.getenv("INVARIANT_API_KEY")
    parameters = {"metadata": session_metadata(ctx), "action": action}
    context = RequestContext.create(
        request_json={},
        dataset_name=ctx.explorer_dataset,
        invariant_authorization=authorization,
        guardrails=ctx.guardrails,
        guardrails_parameters=parameters,
    )

    candidate_trace = ctx.trace + [message]
    return run_task_sync(
        check_guardrails,
        messages=candidate_trace,
        guardrails=guardrails_to_check,
        context=context,
    )
|
||||
|
||||
|
||||
def json_rpc_error_response(
|
||||
id_value: str | int, error_message: str, response_type: str = "error"
|
||||
) -> dict:
|
||||
"""
|
||||
Create a JSON-RPC error response with either error object or content format.
|
||||
|
||||
Args:
|
||||
id_value: The ID of the JSON-RPC request
|
||||
error_message: The error message to include
|
||||
response_type: Either "error" or "content" to determine response format
|
||||
|
||||
Returns:
|
||||
A properly formatted JSON-RPC response dictionary
|
||||
"""
|
||||
base_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": id_value,
|
||||
}
|
||||
|
||||
if response_type == "error":
|
||||
base_response["error"] = {
|
||||
"code": -32600,
|
||||
"message": error_message,
|
||||
}
|
||||
else:
|
||||
base_response["result"] = {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": error_message,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return base_response
|
||||
|
||||
|
||||
async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
    """
    Hook function to intercept tool calls.

    The request is converted into an assistant message with a single tool
    call and checked against the blocking guardrails.

    Args:
        ctx (McpContext): Session context; the message is added to its trace.
        request (dict): The parsed tools/call JSON-RPC request. ``params.name``
            is required; ``params.arguments`` defaults to ``{}`` when absent.

    Returns:
        tuple[dict, bool]: If the request is blocked, a JSON-RPC error
        response explaining the block and True. Otherwise the original
        request and False.
    """
    params = request["params"]
    tool_call = {
        "id": f"call_{request.get('id')}",
        "type": "function",
        "function": {
            "name": params["name"],
            # "arguments" is optional for tools that take no input, so a
            # missing key must not abort the guardrails check.
            "arguments": params.get("arguments", {}),
        },
    }

    message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Check for blocking guardrails
    guardrailing_result = await get_guardrails_check_result(
        ctx, message, action=GuardrailAction.BLOCK
    )

    # Blocked only when the check produced errors that were not seen before
    is_blocked = bool(
        guardrailing_result
        and guardrailing_result.get("errors", [])
        and check_if_new_errors(ctx, guardrailing_result)
    )

    # Record the message: blocked calls are pushed to the explorer when
    # enabled, everything else is only appended to the local trace.
    if is_blocked and ctx.push_explorer:
        await append_and_push_trace(ctx, message, guardrailing_result)
    else:
        ctx.trace.append(message)

    if not is_blocked:
        return request, False

    # Format the errors the same way hook_tool_result does for this message
    return json_rpc_error_response(
        request.get("id"),
        INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
        % format_errors_in_response(guardrailing_result["errors"]),
        response_type=ctx.failure_response_format,
    ), True
|
||||
|
||||
|
||||
async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
    """
    Hook function to intercept tool results.

    Records the server name when the response carries server info. Responses
    to tool calls and tools/list requests (matched through the request-id
    mapping on the context) are checked against the blocking guardrails and
    added to the trace; all other responses are returned untouched.

    Returns:
        dict: The original response, a JSON-RPC error response for a blocked
        tool call, or the response with its tools replaced by ``blocked_*``
        placeholders for a blocked tools/list.
    """
    response_id = result.get("id")
    method = ctx.id_to_method_mapping.get(response_id)
    call_id = f"call_{response_id}"

    # Safely handle result object
    payload = result.get("result", {})
    if payload.get(MCP_SERVER_INFO):
        ctx.mcp_server_name = payload.get(MCP_SERVER_INFO, {}).get("name", "")

    # Build the trace message for the two intercepted methods; anything else
    # (including unknown ids) passes through unchanged.
    if method == MCP_TOOL_CALL:
        message = {
            "role": "tool",
            "content": payload.get("content"),
            "error": payload.get("error"),
            "tool_call_id": call_id,
        }
    elif method == MCP_LIST_TOOLS:
        ctx.tools = payload.get("tools")
        message = {
            "role": "tool",
            "content": json.dumps(result.get("result").get("tools")),
            "tool_call_id": call_id,
        }
    else:
        return result

    # Check for blocking guardrails
    guardrailing_result = await get_guardrails_check_result(
        ctx, message, action=GuardrailAction.BLOCK
    )

    if guardrailing_result and guardrailing_result.get("errors", []):
        if method == MCP_TOOL_CALL:
            # Replace the tool output with an error in the configured
            # failure response format
            result = json_rpc_error_response(
                result.get("id"),
                INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
                % format_errors_in_response(guardrailing_result["errors"]),
                response_type=ctx.failure_response_format,
            )
        else:
            # Replace every listed tool with a parameterless placeholder
            result["result"]["tools"] = [
                {
                    "name": "blocked_" + tool["name"],
                    "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
                    % format_errors_in_response(guardrailing_result["errors"]),
                    # no parameters
                    "inputSchema": {
                        "properties": {},
                        "required": [],
                        "title": "invariant_mcp_server_blockedArguments",
                        "type": "object",
                    },
                    "annotations": {
                        "title": "This tool was blocked by security guardrails.",
                    },
                }
                for tool in result["result"]["tools"]
            ]

    # Add the message to the session trace, pushing to the explorer when enabled
    if ctx.push_explorer:
        await append_and_push_trace(ctx, message, guardrailing_result)
    else:
        ctx.trace.append(message)

    return result
|
||||
|
||||
|
||||
async def stream_and_forward_stdout(
|
||||
mcp_process: subprocess.Popen, ctx: McpContext
|
||||
session_id: str, mcp_process: subprocess.Popen
|
||||
) -> None:
|
||||
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
if mcp_process.poll() is not None:
|
||||
mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}")
|
||||
break
|
||||
|
||||
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
|
||||
if not line:
|
||||
break
|
||||
|
||||
try:
|
||||
# Process complete JSON lines
|
||||
line_str = line.decode(UTF_8).strip()
|
||||
if not line_str:
|
||||
decoded_line = line.decode(UTF_8).strip()
|
||||
if not decoded_line:
|
||||
continue
|
||||
session = session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] server -> client: {decoded_line}")
|
||||
response_body = json.loads(decoded_line)
|
||||
update_mcp_server_in_session_metadata(session, response_body)
|
||||
|
||||
if ctx.verbose:
|
||||
mcp_log(f"[INFO] server -> client: {line_str}")
|
||||
|
||||
parsed_json = json.loads(line_str)
|
||||
processed_json = await hook_tool_result(ctx, parsed_json)
|
||||
|
||||
intercept_response_result, _ = await intercept_response(
|
||||
session_id, session_store, response_body
|
||||
)
|
||||
# Write and flush immediately
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result))
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
|
||||
if line:
|
||||
@@ -407,74 +91,52 @@ async def stream_and_forward_stderr(
|
||||
MCP_LOG_FILE.buffer.flush()
|
||||
|
||||
|
||||
async def process_line(
|
||||
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
|
||||
async def _intercept_request(
|
||||
session_id: str, mcp_process: subprocess.Popen, line: bytes
|
||||
) -> None:
|
||||
"""Process a line of input from stdin, decode it, and forward to mcp_process."""
|
||||
if ctx.verbose:
|
||||
"""
|
||||
Process a line of input from stdin, decode it and check for guardrails.
|
||||
If the request is blocked, it returns a message indicating the block reason
|
||||
otherwise it forwards the request to mcp_process stdin.
|
||||
"""
|
||||
session = session_store.get_session(session_id)
|
||||
if session.attributes.verbose:
|
||||
mcp_log(f"[INFO] client -> server: {line}")
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
try:
|
||||
text = line.decode(UTF_8)
|
||||
parsed_json = json.loads(text)
|
||||
request_body = json.loads(text)
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
return
|
||||
update_session_from_request(session, request_body)
|
||||
# Refresh guardrails
|
||||
await session.load_guardrails()
|
||||
|
||||
if parsed_json.get(MCP_METHOD) is not None:
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD)
|
||||
if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO):
|
||||
ctx.mcp_client_name = (
|
||||
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
hook_tool_call_result = {}
|
||||
is_blocked = False
|
||||
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id, session_store, request_body
|
||||
)
|
||||
|
||||
# Check if this is a tool call request
|
||||
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Refresh guardrails
|
||||
run_task_sync(ctx.load_guardrails)
|
||||
|
||||
# Intercept and potentially block modify the request
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json)
|
||||
if not is_blocked:
|
||||
# If blocked, hook_tool_call_result contains the original request.
|
||||
# Forward the request to the MCP process.
|
||||
# It will handle the request and return a response.
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
mcp_process.stdin.flush()
|
||||
else:
|
||||
# If blocked, hook_tool_call_result contains the block message.
|
||||
# Forward the block message result back to the caller.
|
||||
# The original request is not passed to the MCP process.
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
sys.stdout.buffer.flush()
|
||||
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
request_body={
|
||||
"id": request_body.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
|
||||
},
|
||||
)
|
||||
if is_blocked:
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
sys.stdout.buffer.flush()
|
||||
return
|
||||
else:
|
||||
# pass through the request to the MCP process
|
||||
|
||||
# for list_tools, extend the trace by a tool call
|
||||
if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
# Refresh guardrails
|
||||
run_task_sync(ctx.load_guardrails)
|
||||
ctx.trace.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{parsed_json.get('id')}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tools/list",
|
||||
"arguments": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
|
||||
mcp_process.stdin.flush()
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(request_body))
|
||||
mcp_process.stdin.flush()
|
||||
|
||||
|
||||
async def wait_for_stdin_input(
|
||||
@@ -523,7 +185,7 @@ async def wait_for_stdin_input(
|
||||
|
||||
|
||||
async def run_stdio_input_loop(
|
||||
ctx: McpContext,
|
||||
session_id: str,
|
||||
mcp_process: subprocess.Popen,
|
||||
stdout_task: asyncio.Task,
|
||||
stderr_task: asyncio.Task,
|
||||
@@ -557,7 +219,7 @@ async def run_stdio_input_loop(
|
||||
if not line:
|
||||
continue
|
||||
|
||||
await process_line(ctx, mcp_process, line)
|
||||
await _intercept_request(session_id, mcp_process, line)
|
||||
except (BrokenPipeError, KeyboardInterrupt):
|
||||
# Broken pipe = client disappeared, just start shutdown
|
||||
mcp_log("Client disconnected or keyboard interrupt")
|
||||
@@ -570,7 +232,7 @@ async def run_stdio_input_loop(
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if line:
|
||||
await process_line(ctx, mcp_process, line)
|
||||
await _intercept_request(session_id, mcp_process, line)
|
||||
|
||||
# Terminate process if needed
|
||||
if mcp_process.poll() is None:
|
||||
@@ -631,7 +293,11 @@ async def execute(args: list[str] = None):
|
||||
mcp_log("[INFO] Running with Python version:", sys.version)
|
||||
|
||||
mcp_gateway_args, mcp_server_command_args = split_args(args)
|
||||
ctx = McpContext(mcp_gateway_args)
|
||||
session_id = generate_session_id()
|
||||
await session_store.initialize_session(
|
||||
session_id,
|
||||
McpAttributes.from_cli_args(mcp_gateway_args),
|
||||
)
|
||||
|
||||
mcp_process = subprocess.Popen(
|
||||
mcp_server_command_args,
|
||||
@@ -642,8 +308,10 @@ async def execute(args: list[str] = None):
|
||||
)
|
||||
|
||||
# Start async tasks for stdout and stderr
|
||||
stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx))
|
||||
stdout_task = asyncio.create_task(
|
||||
stream_and_forward_stdout(session_id, mcp_process)
|
||||
)
|
||||
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task)
|
||||
await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task)
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""Context manager for MCP (Model Context Protocol) gateway."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from gateway.integrations.explorer import (
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailRuleSet
|
||||
|
||||
|
||||
class McpContext:
    """Singleton class to manage MCP context and state.

    The first construction parses the gateway CLI arguments and sets up the
    per-run state (trace, tools, guardrails, metadata). Every later
    construction returns the same instance and leaves that state untouched.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance lazily; __init__ decides whether
        # it still has to be populated.
        if cls._instance is None:
            cls._instance = super(McpContext, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, cli_args: list):
        """Populate the shared instance from ``cli_args`` (first call only).

        Args:
            cli_args: Gateway command line arguments. Known options are
                parsed by ``_parse_cli_args``; everything else must have the
                form ``--metadata-<key>=<value>``.

        Raises:
            AssertionError: If an unrecognised argument is not a well-formed
                ``--metadata-<key>=<value>`` pair.
        """
        if not hasattr(self, "_initialized"):
            self._initialized = False
        if self._initialized:
            return

        config, extra_args = self._parse_cli_args(cli_args)

        # The project name is used to identify the dataset in Invariant Explorer.
        self.explorer_dataset = config.project_name
        # whether to push traces to Invariant Explorer
        self.push_explorer = config.push_explorer
        # the format to use to communicate guardrail failures to the client
        self.failure_response_format = config.failure_response_format
        # verbose logging of in/out
        self.verbose = config.verbose

        # trace of this MCP session
        self.trace = []
        # tools available to the MCP server
        self.tools = []

        # configured guardrails
        self.guardrails = GuardrailRuleSet(
            blocking_guardrails=[], logging_guardrails=[]
        )

        # parsed from CLI (all --metadata-* args)
        self.extra_metadata: Dict[str, str] = self._parse_extra_metadata(extra_args)

        # captured from MCP calls/responses
        self.mcp_client_name = ""
        self.mcp_server_name = ""

        # We send the same trace messages for guardrails analysis multiple times.
        # We need to deduplicate them before sending to the explorer.
        self.annotations = []
        self.trace_id = None
        self.local_session_id = str(uuid.uuid4())
        self.last_trace_length = 0
        self.id_to_method_mapping = {}
        self._initialized = True

    @staticmethod
    def _parse_extra_metadata(extra_args: list) -> Dict[str, str]:
        """Turn ``--metadata-<key>=<value>`` arguments into a ``{key: value}`` dict.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=`` characters.

        Raises:
            AssertionError: If an argument has no ``=`` or its key part does
                not start with ``--metadata-``.
        """
        metadata: Dict[str, str] = {}
        for arg in extra_args:
            assert "=" in arg, f"Invalid extra metadata argument: {arg}"
            # maxsplit=1: a value such as "a=b" must not break the unpacking.
            key, value = arg.split("=", 1)
            assert key.startswith(
                "--metadata-"
            ), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
            metadata[key[len("--metadata-") :]] = value
        return metadata

    def _parse_cli_args(self, cli_args: list) -> tuple[argparse.Namespace, list[str]]:
        """Parse command line arguments.

        Returns:
            The ``(namespace, remaining_args)`` pair produced by
            ``ArgumentParser.parse_known_args``; ``remaining_args`` holds the
            arguments that are not declared below.
        """
        parser = argparse.ArgumentParser(description="MCP Gateway")
        parser.add_argument(
            "--project-name",
            help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
            type=str,
            default=f"mcp-capture-{random.randint(1, 100)}",
        )
        parser.add_argument(
            "--push-explorer",
            help="Enable pushing traces to Invariant Explorer",
            action="store_true",
        )
        parser.add_argument(
            "--verbose",
            help="Enable verbose logging",
            action="store_true",
        )
        parser.add_argument(
            "--failure-response-format",
            help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
            type=str,
            default="error",
        )

        return parser.parse_known_args(cli_args)

    async def load_guardrails(self):
        """Fetch the guardrails for this project and store them on the context.

        The client/server names passed along come from the ``client`` /
        ``server`` metadata entries when present, otherwise from the names
        captured from MCP traffic.

        Raises:
            TypeError: If the ``INVARIANT_API_KEY`` environment variable is
                not set (``os.getenv`` then returns ``None``, which cannot be
                concatenated to the ``"Bearer "`` prefix).
        """
        self.guardrails = await fetch_guardrails_from_explorer(
            self.explorer_dataset,
            "Bearer " + os.getenv("INVARIANT_API_KEY"),
            self.extra_metadata.get("client", self.mcp_client_name),
            self.extra_metadata.get("server", self.mcp_server_name),
        )
|
||||
@@ -15,10 +15,6 @@ from gateway.common.constants import (
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
MCP_PARAMS,
|
||||
MCP_RESULT,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_CLIENT_INFO,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
@@ -28,7 +24,9 @@ from gateway.common.mcp_sessions_manager import (
|
||||
from gateway.common.mcp_utils import (
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
hook_tool_call_response,
|
||||
intercept_response,
|
||||
update_mcp_server_in_session_metadata,
|
||||
update_session_from_request,
|
||||
)
|
||||
|
||||
MCP_SERVER_POST_HEADERS = {
|
||||
@@ -72,16 +70,7 @@ async def mcp_post_sse_gateway(
|
||||
request_body_bytes = await request.body()
|
||||
request_body = json.loads(request_body_bytes)
|
||||
session = session_store.get_session(session_id)
|
||||
if request_body.get(MCP_METHOD) and request_body.get("id"):
|
||||
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
|
||||
MCP_CLIENT_INFO
|
||||
):
|
||||
session.attributes.metadata["mcp_client"] = (
|
||||
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
update_session_from_request(session, request_body)
|
||||
|
||||
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Intercept and potentially block the request
|
||||
@@ -137,8 +126,6 @@ async def mcp_post_sse_gateway(
|
||||
print(f"[MCP POST] Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Request error") from e
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"[MCP POST] Unexpected error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
@@ -340,59 +327,18 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8)
|
||||
session = session_store.get_session(session_id)
|
||||
try:
|
||||
response_json = json.loads(sse.data)
|
||||
response_body = json.loads(sse.data)
|
||||
update_mcp_server_in_session_metadata(session, response_body)
|
||||
|
||||
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
|
||||
MCP_SERVER_INFO
|
||||
):
|
||||
session.attributes.metadata["mcp_server"] = (
|
||||
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
|
||||
intercept_response_result, is_blocked = await intercept_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body=response_body,
|
||||
)
|
||||
if is_blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(intercept_response_result)}\n\n".encode(
|
||||
UTF_8
|
||||
)
|
||||
|
||||
method = session.id_to_method_mapping.get(response_json.get("id"))
|
||||
if method == MCP_TOOL_CALL:
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json=response_json,
|
||||
)
|
||||
# Update the event bytes with hook_tool_call_response.
|
||||
# hook_tool_call_response is same as response_json if no guardrail is violated.
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = (
|
||||
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json={
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"content": json.dumps(
|
||||
response_json.get(MCP_RESULT).get("tools")
|
||||
),
|
||||
"tools": response_json.get(MCP_RESULT).get("tools"),
|
||||
},
|
||||
},
|
||||
is_tools_list=True,
|
||||
)
|
||||
# Update the event bytes with hook_tool_call_response.
|
||||
# hook_tool_call_response is same as response_json if no guardrail is violated.
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = (
|
||||
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
f"[MCP SSE] Error parsing message JSON: {e}",
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -13,12 +10,8 @@ from fastapi.responses import StreamingResponse
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
INVARIANT_SESSION_ID_PREFIX,
|
||||
MCP_CLIENT_INFO,
|
||||
MCP_LIST_TOOLS,
|
||||
MCP_METHOD,
|
||||
MCP_PARAMS,
|
||||
MCP_RESULT,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_TOOL_CALL,
|
||||
UTF_8,
|
||||
)
|
||||
@@ -27,9 +20,13 @@ from gateway.common.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
generate_session_id,
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
hook_tool_call_response,
|
||||
intercept_response,
|
||||
update_mcp_client_info_in_session,
|
||||
update_mcp_server_in_session_metadata,
|
||||
update_tool_call_id_in_session,
|
||||
)
|
||||
|
||||
gateway = APIRouter()
|
||||
@@ -69,7 +66,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
|
||||
# If a session ID is provided in the request headers, it was already initialized
|
||||
# in McpSessionsManager. This might be a session ID returned by the MCP server
|
||||
# or a session ID generated in the gateway.
|
||||
_update_tool_call_id_in_session(session_id, request_body)
|
||||
update_tool_call_id_in_session(
|
||||
session_store.get_session(session_id), request_body
|
||||
)
|
||||
elif is_initialization_request:
|
||||
# If this is an initialization request, we generate a session ID,
|
||||
# We don't call initialize_session here because we don't know
|
||||
@@ -77,7 +76,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
|
||||
# If later in the response from MCP server, we don't receive a session ID then this
|
||||
# will be initialized and returned back to the client else this will be
|
||||
# overwritten by the session ID returned by the MCP server.
|
||||
session_id = _generate_session_id()
|
||||
session_id = generate_session_id()
|
||||
|
||||
# Intercept the request and check for guardrails.
|
||||
if not is_initialization_request:
|
||||
@@ -115,9 +114,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
|
||||
|
||||
# Update client info if this is an initialization request
|
||||
if is_initialization_request:
|
||||
_update_mcp_client_info_in_session(
|
||||
session_id=session_id,
|
||||
request_body=request_body,
|
||||
update_mcp_client_info_in_session(
|
||||
session_store.get_session(session_id),
|
||||
request_body,
|
||||
)
|
||||
|
||||
# If the response is JSON type, handle it as a JSON response.
|
||||
@@ -281,52 +280,17 @@ def _get_mcp_server_endpoint(request: Request) -> str:
|
||||
return get_mcp_server_base_url(request) + "/mcp/"
|
||||
|
||||
|
||||
def _generate_session_id() -> str:
    """
    Create a fresh gateway-side session ID.
    Needed when the MCP server is session less and therefore does not
    hand out a session ID of its own.
    """
    random_suffix = uuid.uuid4().hex
    return INVARIANT_SESSION_ID_PREFIX + random_suffix
|
||||
|
||||
|
||||
def _update_tool_call_id_in_session(session_id: str, request_body: dict) -> None:
    """
    Record which method a request ID belongs to in the session.

    The mapping is stored in ``session.id_to_method_mapping`` keyed by the
    request's ``id``. Requests without a method or without an ``id`` are
    ignored.
    """
    session = session_store.get_session(session_id)
    method = request_body.get(MCP_METHOD)
    request_id = request_body.get("id")
    # Compare against None instead of relying on truthiness: 0 is a valid
    # JSON-RPC request id and must still be recorded.
    if method and request_id is not None:
        session.id_to_method_mapping[request_id] = method
|
||||
|
||||
|
||||
def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> None:
    """
    Store the MCP client name from the request in the session metadata.

    Nothing is written unless the request carries params with client info;
    a client info entry without a name is stored as an empty string.
    """
    session = session_store.get_session(session_id)
    params = request_body.get(MCP_PARAMS)
    if not params:
        return
    client_info = params.get(MCP_CLIENT_INFO)
    if not client_info:
        return
    session.attributes.metadata["mcp_client"] = client_info.get("name", "")
|
||||
|
||||
|
||||
def _update_mcp_response_info_in_session(
|
||||
session_id: str, response_json: dict, is_json_response: bool
|
||||
session_id: str, response_body: dict, is_json_response: bool
|
||||
) -> None:
|
||||
"""
|
||||
Update the MCP response info in the session metadata.
|
||||
"""
|
||||
session = session_store.get_session(session_id)
|
||||
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
|
||||
MCP_SERVER_INFO
|
||||
):
|
||||
session.attributes.metadata["mcp_server"] = (
|
||||
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
|
||||
)
|
||||
session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse"
|
||||
update_mcp_server_in_session_metadata(session, response_body)
|
||||
session.attributes.metadata["server_response_type"] = (
|
||||
"json" if is_json_response else "sse"
|
||||
)
|
||||
|
||||
|
||||
def _is_initialization_request(request_body: dict) -> bool:
|
||||
@@ -353,18 +317,20 @@ async def _handle_mcp_json_response(
|
||||
# return the error message else return the response as is
|
||||
response_content = response.content
|
||||
# The server response is empty string when client sends "notifications/initialized"
|
||||
response_json = (
|
||||
response_body = (
|
||||
json.loads(response_content.decode(UTF_8)) if response_content else {}
|
||||
)
|
||||
if response_json:
|
||||
if response_body:
|
||||
_update_mcp_response_info_in_session(
|
||||
session_id=session_id, response_json=response_json, is_json_response=True
|
||||
session_id=session_id, response_body=response_body, is_json_response=True
|
||||
)
|
||||
response_code = response.status_code
|
||||
|
||||
if not is_initialization_request:
|
||||
intercept_response_result, blocked = await _intercept_response(
|
||||
session_id=session_id, response_json=response_json
|
||||
intercept_response_result, blocked = await intercept_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body=response_body,
|
||||
)
|
||||
if blocked:
|
||||
response_content = json.dumps(intercept_response_result).encode(UTF_8)
|
||||
@@ -405,14 +371,15 @@ async def _handle_mcp_streaming_response(
|
||||
if not stripped_line:
|
||||
break # End of stream
|
||||
if buffer:
|
||||
response_json = json.loads(stripped_line.split("data: ")[1].strip())
|
||||
response_body = json.loads(stripped_line.split("data: ")[1].strip())
|
||||
if not is_initialization_request:
|
||||
(
|
||||
intercept_response_result,
|
||||
blocked,
|
||||
) = await _intercept_response(
|
||||
) = await intercept_response(
|
||||
session_id=session_id,
|
||||
response_json=response_json,
|
||||
session_store=session_store,
|
||||
response_body=response_body,
|
||||
)
|
||||
if blocked:
|
||||
yield (
|
||||
@@ -423,7 +390,7 @@ async def _handle_mcp_streaming_response(
|
||||
else:
|
||||
_update_mcp_response_info_in_session(
|
||||
session_id=session_id,
|
||||
response_json=response_json,
|
||||
response_body=response_body,
|
||||
is_json_response=False,
|
||||
)
|
||||
yield f"{buffer}\n{stripped_line}\n\n"
|
||||
@@ -482,46 +449,3 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
|
||||
media_type="application/json",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _intercept_response(
    session_id: str, response_json: dict
) -> Tuple[dict, bool]:
    """
    Intercept the response and check for guardrails.

    The method of the originating request is looked up through the
    response ``id`` in the session's ``id_to_method_mapping``.
    If the response is blocked, it returns a message indicating the block
    reason with a boolean flag set to True. If the response is not blocked,
    it returns the original response with a boolean flag set to False.
    Responses to other methods, and tools/list responses without a result,
    are returned unchanged with the flag set to False.
    """
    session = session_store.get_session(session_id)
    method = session.id_to_method_mapping.get(response_json.get("id"))

    # Intercept and potentially block tool call response
    if method == MCP_TOOL_CALL:
        result, blocked = await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_json=response_json,
        )
        return result, blocked

    # Intercept and potentially block list tool call response
    if method == MCP_LIST_TOOLS:
        list_result = response_json.get(MCP_RESULT)
        if list_result is None:
            # No result to inspect (e.g. a JSON-RPC error response): there is
            # no tool list to store or check, so hand the response back as is
            # instead of failing on the missing key.
            return response_json, False

        tools = list_result.get("tools")
        # store tools in metadata
        session.attributes.metadata["tools"] = tools
        # store tools/list tool call in trace
        result, blocked = await hook_tool_call_response(
            session_id=session_id,
            session_store=session_store,
            response_json={
                "id": response_json.get("id"),
                "result": {
                    "content": json.dumps(tools),
                    "tools": tools,
                },
            },
            is_tools_list=True,
        )
        return result, blocked

    return response_json, False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""A MCP client implementation that interacts with MCP server to make tool calls."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from datetime import timedelta
|
||||
@@ -48,7 +49,6 @@ class MCPClient:
|
||||
for key, value in metadata_keys.items():
|
||||
args.append("--metadata-" + key + "=" + value)
|
||||
|
||||
|
||||
if push_to_explorer:
|
||||
args.append("--push-explorer")
|
||||
args.extend(
|
||||
@@ -133,7 +133,7 @@ async def run(
|
||||
project_name,
|
||||
server_script_path,
|
||||
push_to_explorer,
|
||||
metadata_keys=metadata_keys
|
||||
metadata_keys=metadata_keys,
|
||||
)
|
||||
listed_tools = await client.session.list_tools()
|
||||
if tool_name == "tools/list":
|
||||
@@ -142,4 +142,6 @@ async def run(
|
||||
else:
|
||||
return await client.call_tool(tool_name, tool_args)
|
||||
finally:
|
||||
if push_to_explorer:
|
||||
await asyncio.sleep(2)
|
||||
await client.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user