diff --git a/gateway/common/constants.py b/gateway/common/constants.py index 39010c9..00cfa93 100644 --- a/gateway/common/constants.py +++ b/gateway/common/constants.py @@ -1,5 +1,7 @@ """Common constants used in the gateway.""" +DEFAULT_API_URL = "https://explorer.invariantlabs.ai" + IGNORED_HEADERS = [ "accept-encoding", "host", diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py index 14a627f..0a0e0f7 100644 --- a/gateway/common/mcp_sessions_manager.py +++ b/gateway/common/mcp_sessions_manager.py @@ -16,7 +16,7 @@ from invariant_sdk.types.push_traces import PushTracesRequest from pydantic import BaseModel, Field, PrivateAttr from starlette.datastructures import Headers -from gateway.common.constants import INVARIANT_SESSION_ID_PREFIX +from gateway.common.constants import DEFAULT_API_URL, INVARIANT_SESSION_ID_PREFIX from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction from gateway.common.request_context import RequestContext from gateway.integrations.explorer import ( @@ -25,8 +25,6 @@ from gateway.integrations.explorer import ( ) from gateway.integrations.guardrails import check_guardrails -DEFAULT_API_URL = "https://explorer.invariantlabs.ai" - def user_and_host() -> str: """Get the current user and hostname.""" @@ -61,30 +59,15 @@ class McpSession(BaseModel): # and other session-related operations _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - def get_invariant_api_key(self) -> str: + def _get_invariant_api_key(self) -> str: """Get the Invariant API key for the session.""" if self.attributes.invariant_api_key: return self.attributes.invariant_api_key return os.getenv("INVARIANT_API_KEY") - def get_invariant_authorization(self) -> str: + def _get_invariant_authorization(self) -> str: """Get the Invariant authorization header for the session.""" - return "Bearer " + self.get_invariant_api_key() - - async def load_guardrails(self) -> None: - """ - Load guardrails for the session. - - This method fetches guardrails from the Invariant Explorer and assigns them to the session. - """ - print("Inside load_guardrails attributes: ", self.attributes, flush=True) - self.guardrails = await fetch_guardrails_from_explorer( - self.attributes.explorer_dataset, - self.get_invariant_authorization(), - # pylint: disable=no-member - self.attributes.metadata.get("mcp_client"), - self.attributes.metadata.get("mcp_server"), - ) + return "Bearer " + self._get_invariant_api_key() def _deduplicate_annotations(self, new_annotations: list) -> list: """Deduplicate new_annotations using the annotations in the session.""" @@ -94,6 +77,20 @@ class McpSession(BaseModel): deduped_annotations.append(annotation) return deduped_annotations + async def load_guardrails(self) -> None: + """ + Load guardrails for the session. + + This method fetches guardrails from the Invariant Explorer and assigns them to the session. + """ + self.guardrails = await fetch_guardrails_from_explorer( + self.attributes.explorer_dataset, + self._get_invariant_authorization(), + # pylint: disable=no-member + self.attributes.metadata.get("mcp_client"), + self.attributes.metadata.get("mcp_server"), + ) + @contextlib.asynccontextmanager async def session_lock(self): """ @@ -134,15 +131,10 @@ class McpSession(BaseModel): return {} # Prepare context and select appropriate guardrails - print( - "Inside get_guardrails_check_result attributes: ", - self.attributes, - flush=True, - ) context = RequestContext.create( request_json={}, dataset_name=self.attributes.explorer_dataset, - invariant_authorization=self.get_invariant_authorization(), + invariant_authorization=self._get_invariant_authorization(), guardrails=self.guardrails, guardrails_parameters={ "metadata": self.session_metadata(), @@ -210,11 +202,10 @@ class McpSession(BaseModel): This is an internal method that should only be called within a lock. """ - print("Inside _push_trace_update attributes: ", self.attributes, flush=True) try: client = AsyncClient( api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL), - api_key=self.get_invariant_api_key(), + api_key=self._get_invariant_api_key(), ) # If no trace exists, create a new one @@ -247,7 +238,7 @@ class McpSession(BaseModel): self.annotations.extend(deduplicated_annotations) self.last_trace_length = len(self.messages) except Exception as e: # pylint: disable=broad-except - print(f"[MCP SSE] Error pushing trace for session {self.session_id}: {e}") + print(f"[MCP] Error pushing trace for session {self.session_id}: {e}") async def add_pending_error_message(self, error_message: dict) -> None: """ diff --git a/gateway/common/mcp_utils.py b/gateway/common/mcp_utils.py index c6461e9..208ad19 100644 --- a/gateway/common/mcp_utils.py +++ b/gateway/common/mcp_utils.py @@ -1,7 +1,9 @@ """MCP utility functions.""" import asyncio +import json import re +import uuid from typing import Tuple @@ -9,18 +11,83 @@ from fastapi import Request, HTTPException from gateway.common.constants import ( INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE, + INVARIANT_SESSION_ID_PREFIX, + MCP_CLIENT_INFO, MCP_SERVER_BASE_URL_HEADER, + MCP_LIST_TOOLS, + MCP_METHOD, MCP_PARAMS, MCP_RESULT, + MCP_SERVER_INFO, + MCP_TOOL_CALL, ) from gateway.common.guardrails import GuardrailAction from gateway.common.mcp_sessions_manager import ( + McpSession, McpSessionsManager, ) from gateway.integrations.explorer import create_annotations_from_guardrails_errors from gateway.mcp.log import format_errors_in_response +def _check_if_new_errors( + session_id: str, session_store: McpSessionsManager, guardrails_result: dict +) -> bool: + """Checks if there are new errors in the guardrails result.""" + session = session_store.get_session(session_id) + annotations = create_annotations_from_guardrails_errors( + guardrails_result.get("errors", []) + ) + for annotation in annotations: + if annotation not in session.annotations: + return True + return False + + +def generate_session_id() -> str: + """ + Generate a new session ID. + If the MCP server is session less then we don't have a session ID from the MCP server. + """ + return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex + + +def update_mcp_server_in_session_metadata( + session: McpSession, response_body: dict +) -> None: + """Update the MCP server information in the session metadata.""" + if response_body.get(MCP_RESULT) and response_body.get(MCP_RESULT).get( + MCP_SERVER_INFO + ): + session.attributes.metadata["mcp_server"] = ( + response_body.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") + ) + + +def update_tool_call_id_in_session(session: McpSession, request_body: dict) -> None: + """Updates the tool call ID in the session.""" + if request_body.get(MCP_METHOD) and request_body.get("id"): + session.id_to_method_mapping[request_body.get("id")] = request_body.get( + MCP_METHOD + ) + + +def update_mcp_client_info_in_session(session: McpSession, request_body: dict) -> None: + """Update the MCP client info in the session metadata.""" + if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( + MCP_CLIENT_INFO + ): + session.attributes.metadata["mcp_client"] = ( + request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") + ) + + +def update_session_from_request(session: McpSession, request_body: dict) -> None: + """Update the MCP client information and request id in the session.""" + update_mcp_client_info_in_session(session, request_body) + update_tool_call_id_in_session(session, request_body) + + def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: """ Convert localhost or 127.0.0.1 in an address to host.docker.internal @@ -73,7 +140,13 @@ async def hook_tool_call( Args: session_id (str): The session ID associated with the request. + session_store (McpSessionsManager): The session store to manage sessions. request_body (dict): The request JSON to be processed. + + Returns: + Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean + indicating whether the request was blocked. If the request is blocked, the + dict will contain an error message else it will contain the original request. """ tool_call = { "id": f"call_{request_body.get('id')}", @@ -89,14 +162,13 @@ async def hook_tool_call( guardrails_result = await session.get_guardrails_check_result( message, action=GuardrailAction.BLOCK ) - print("[hook_tool_call] Guardrails result:", guardrails_result, flush=True) # If the request is blocked, return a message indicating the block reason. # If there are new errors, run append_and_push_trace in background. # If there are no new errors, just return the original request. if ( guardrails_result and guardrails_result.get("errors", []) - and check_if_new_errors(session_id, session_store, guardrails_result) + and _check_if_new_errors(session_id, session_store, guardrails_result) ): # Add the trace to the explorer asyncio.create_task( @@ -120,24 +192,10 @@ async def hook_tool_call( return request_body, False -def check_if_new_errors( - session_id: str, session_store: McpSessionsManager, guardrails_result: dict -) -> bool: - """Checks if there are new errors in the guardrails result.""" - session = session_store.get_session(session_id) - annotations = create_annotations_from_guardrails_errors( - guardrails_result.get("errors", []) - ) - for annotation in annotations: - if annotation not in session.annotations: - return True - return False - - async def hook_tool_call_response( session_id: str, session_store: McpSessionsManager, - response_json: dict, + response_body: dict, is_tools_list=False, ) -> dict: """ @@ -145,19 +203,21 @@ async def hook_tool_call_response( Hook to process the response JSON after receiving it from the MCP server. Args: session_id (str): The session ID associated with the request. - response_json (dict): The response JSON to be processed. + session_store (McpSessionsManager): The session store to manage sessions. + response_body (dict): The response JSON to be processed. + is_tools_list (bool): Flag to indicate if the response is from a tools/list call. Returns: dict: The response JSON is returned if no guardrail is violated else an error dict is returned. """ - blocked = False + is_blocked = False + result = response_body message = { "role": "tool", - "tool_call_id": f"call_{response_json.get('id')}", - "content": response_json.get(MCP_RESULT).get("content"), - "error": response_json.get(MCP_RESULT).get("error"), + "tool_call_id": f"call_{result.get('id')}", + "content": result.get(MCP_RESULT).get("content"), + "error": result.get(MCP_RESULT).get("error"), } - result = response_json session = session_store.get_session(session_id) guardrails_result = await session.get_guardrails_check_result( message, action=GuardrailAction.BLOCK @@ -166,14 +226,14 @@ async def hook_tool_call_response( if ( guardrails_result and guardrails_result.get("errors", []) - and check_if_new_errors(session_id, session_store, guardrails_result) + and _check_if_new_errors(session_id, session_store, guardrails_result) ): - blocked = True + is_blocked = True # If the request is blocked, return a message indicating the block reason if not is_tools_list: result = { "jsonrpc": "2.0", - "id": response_json.get("id"), + "id": response_body.get("id"), "error": { "code": -32600, "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE @@ -184,7 +244,7 @@ async def hook_tool_call_response( # special error response for tools/list tool call result = { "jsonrpc": "2.0", - "id": response_json.get("id"), + "id": response_body.get("id"), "result": { "tools": [ { @@ -201,13 +261,64 @@ async def hook_tool_call_response( "title": "This tool was blocked by security guardrails.", }, } - for tool in response_json["result"]["tools"] + for tool in response_body["result"]["tools"] ] }, } - # Push trace to the explorer - don't block on its response - asyncio.create_task( - session_store.add_message_to_session(session_id, message, guardrails_result) - ) - return result, blocked + # Push trace to the explorer + await session_store.add_message_to_session(session_id, message, guardrails_result) + return result, is_blocked + + +async def intercept_response( + session_id: str, session_store: McpSessionsManager, response_body: dict +) -> Tuple[dict, bool]: + """ + Intercept the response and check for guardrails. + This function is used to intercept responses and check for guardrails. + If the response is blocked, it returns a message indicating the block + reason with a boolean flag set to True. If the response is not blocked, + it returns the original response with a boolean flag set to False. + + Args: + session_id (str): The session ID associated with the request. + session_store (McpSessionsManager): The session store to manage sessions. + response_body (dict): The response JSON to be processed. + + Returns: + Tuple[dict, bool]: A tuple containing the processed response JSON + and a boolean indicating whether the response was blocked. + """ + session = session_store.get_session(session_id) + method = session.id_to_method_mapping.get(response_body.get("id")) + + intercept_response_result = response_body + is_blocked = False + # Intercept and potentially block tool call response + if method == MCP_TOOL_CALL: + intercept_response_result, is_blocked = await hook_tool_call_response( + session_id=session_id, + session_store=session_store, + response_body=response_body, + ) + # Intercept and potentially block list tool call response + elif method == MCP_LIST_TOOLS: + # store tools in metadata + session_store.get_session(session_id).attributes.metadata["tools"] = ( + response_body.get(MCP_RESULT).get("tools") + ) + intercept_response_result, is_blocked = await hook_tool_call_response( + session_id=session_id, + session_store=session_store, + response_body={ + "jsonrpc": "2.0", + "id": response_body.get("id"), + "result": { + "content": json.dumps(response_body.get(MCP_RESULT).get("tools")), + "tools": response_body.get(MCP_RESULT).get("tools"), + }, + }, + is_tools_list=True, + ) + return intercept_response_result, is_blocked diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index e2b2a77..e547831 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -6,6 +6,7 @@ import json from typing import Any, Dict, List from fastapi import HTTPException +from gateway.common.constants import DEFAULT_API_URL from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse @@ -13,8 +14,6 @@ from invariant_sdk.types.annotations import AnnotationCreate import httpx -DEFAULT_API_URL = "https://explorer.invariantlabs.ai" - def create_annotations_from_guardrails_errors( guardrails_errors: List[dict], diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 0fed353..3424881 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -9,14 +9,13 @@ from functools import wraps from fastapi import HTTPException import httpx +from gateway.common.constants import DEFAULT_API_URL from gateway.common.guardrails import Guardrail from gateway.common.request_context import RequestContext from gateway.common.authorization import ( INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER, ) -DEFAULT_API_URL = "https://explorer.invariantlabs.ai" - # Timestamps of last API calls per guardrails string _guardrails_cache = {} diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py index 865d11c..4f262c8 100644 --- a/gateway/mcp/mcp.py +++ b/gateway/mcp/mcp.py @@ -1,63 +1,36 @@ """Gateway for MCP (Model Context Protocol) integration with Invariant.""" import asyncio -import getpass import json import os import platform import select -import socket import subprocess import sys -from invariant_sdk.async_client import AsyncClient -from invariant_sdk.types.append_messages import AppendMessagesRequest -from invariant_sdk.types.push_traces import PushTracesRequest - from gateway.common.constants import ( - INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, - INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE, MCP_METHOD, - MCP_CLIENT_INFO, - MCP_PARAMS, - MCP_SERVER_INFO, MCP_TOOL_CALL, MCP_LIST_TOOLS, UTF_8, ) -from gateway.common.guardrails import GuardrailAction -from gateway.common.request_context import RequestContext -from gateway.integrations.explorer import create_annotations_from_guardrails_errors -from gateway.integrations.guardrails import check_guardrails -from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response -from gateway.mcp.mcp_context import McpContext -from gateway.mcp.task_utils import run_task_sync +from gateway.common.mcp_sessions_manager import ( + McpAttributes, + McpSessionsManager, +) +from gateway.common.mcp_utils import ( + generate_session_id, + hook_tool_call, + intercept_response, + update_mcp_server_in_session_metadata, + update_session_from_request, +) +from gateway.mcp.log import mcp_log, MCP_LOG_FILE - -DEFAULT_API_URL = "https://explorer.invariantlabs.ai" STATUS_EOF = "eof" STATUS_DATA = "data" STATUS_WAIT = "wait" - - -def user_and_host() -> str: - """Get the current user and hostname.""" - username = getpass.getuser() - hostname = socket.gethostname() - - return f"{username}@{hostname}" - - -def session_metadata(ctx: McpContext) -> dict: - """Generate metadata for the current session.""" - return { - "session_id": ctx.local_session_id, - "system_user": user_and_host(), - "mcp_client": ctx.mcp_client_name, - "mcp_server": ctx.mcp_server_name, - "tools": ctx.tools, - **(ctx.extra_metadata or {}), - } +session_store = McpSessionsManager() def write_as_utf8_bytes(data: dict) -> bytes: @@ -65,326 +38,37 @@ def write_as_utf8_bytes(data: dict) -> bytes: return json.dumps(data).encode(UTF_8) + b"\n" -def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list: - """Deduplicate new_annotations using the annotations in the context.""" - deduped_annotations = [] - for annotation in new_annotations: - if annotation not in ctx.annotations: - deduped_annotations.append(annotation) - return deduped_annotations - - -def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool: - """Checks if there are new errors in the guardrails result.""" - annotations = create_annotations_from_guardrails_errors( - guardrails_result.get("errors", []) - ) - for annotation in annotations: - if annotation not in ctx.annotations: - return True - return False - - -async def append_and_push_trace( - ctx: McpContext, message: dict, guardrails_result: dict -) -> None: - """ - Append a message to the trace if it exists or create a new one - and push it to the Invariant Explorer. - """ - annotations = [] - if guardrails_result and guardrails_result.get("errors", []): - annotations = create_annotations_from_guardrails_errors( - guardrails_result["errors"] - ) - - if ctx.guardrails.logging_guardrails: - logging_guardrails_check_result = await get_guardrails_check_result( - ctx, message, action=GuardrailAction.LOG - ) - if logging_guardrails_check_result and logging_guardrails_check_result.get( - "errors", [] - ): - annotations.extend( - create_annotations_from_guardrails_errors( - logging_guardrails_check_result["errors"] - ) - ) - deduplicated_annotations = deduplicate_annotations(ctx, annotations) - - try: - # If the trace_id is None, create a new trace with the messages. - # Otherwise, append the message to the existing trace. - client = AsyncClient( - api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL), - ) - - if ctx.trace_id is None: - ctx.trace.append(message) - - # default metadata - metadata = {"source": "mcp"} - # include MCP session metadata - metadata.update(session_metadata(ctx)) - - response = await client.push_trace( - PushTracesRequest( - messages=[ctx.trace], - dataset=ctx.explorer_dataset, - metadata=[metadata], - annotations=[deduplicated_annotations], - ) - ) - ctx.trace_id = response.id[0] - ctx.last_trace_length = len(ctx.trace) - ctx.annotations.extend(deduplicated_annotations) - else: - ctx.trace.append(message) - response = await client.append_messages( - AppendMessagesRequest( - trace_id=ctx.trace_id, - messages=ctx.trace[ctx.last_trace_length :], - annotations=deduplicated_annotations, - ) - ) - ctx.last_trace_length = len(ctx.trace) - ctx.annotations.extend(deduplicated_annotations) - except Exception as e: # pylint: disable=broad-except - mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e) - - -async def get_guardrails_check_result( - ctx: McpContext, - message: dict, - action: GuardrailAction = GuardrailAction.BLOCK, -) -> dict: - """ - Check against guardrails of type action. - Works in both sync and async contexts by always using a dedicated thread. - """ - # Skip if no guardrails are configured for this action - if not ( - (ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK) - or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG) - ): - return {} - - # Prepare context and select appropriate guardrails - context = RequestContext.create( - request_json={}, - dataset_name=ctx.explorer_dataset, - invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), - guardrails=ctx.guardrails, - guardrails_parameters={"metadata": session_metadata(ctx), "action": action}, - ) - - guardrails_to_check = ( - ctx.guardrails.blocking_guardrails - if action == GuardrailAction.BLOCK - else ctx.guardrails.logging_guardrails - ) - - return run_task_sync( - check_guardrails, - messages=ctx.trace + [message], - guardrails=guardrails_to_check, - context=context, - ) - - -def json_rpc_error_response( - id_value: str | int, error_message: str, response_type: str = "error" -) -> dict: - """ - Create a JSON-RPC error response with either error object or content format. - - Args: - id_value: The ID of the JSON-RPC request - error_message: The error message to include - response_type: Either "error" or "content" to determine response format - - Returns: - A properly formatted JSON-RPC response dictionary - """ - base_response = { - "jsonrpc": "2.0", - "id": id_value, - } - - if response_type == "error": - base_response["error"] = { - "code": -32600, - "message": error_message, - } - else: - base_response["result"] = { - "content": [ - { - "type": "text", - "text": error_message, - } - ] - } - - return base_response - - -async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: - """ - Hook function to intercept tool calls. - - If the request is blocked, it returns a tuple with a message explaining the block - and a flag indicating the request was blocked. - Otherwise it returns the original request and a flag indicating it was not blocked. - """ - tool_call = { - "id": f"call_{request.get('id')}", - "type": "function", - "function": { - "name": request["params"]["name"], - "arguments": request["params"]["arguments"], - }, - } - - message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} - - # Check for blocking guardrails - guardrailing_result = await get_guardrails_check_result( - ctx, message, action=GuardrailAction.BLOCK - ) - - # If the request is blocked, return a message indicating the block reason. - if ( - guardrailing_result - and guardrailing_result.get("errors", []) - and check_if_new_errors(ctx, guardrailing_result) - ): - if ctx.push_explorer: - await append_and_push_trace(ctx, message, guardrailing_result) - else: - ctx.trace.append(message) - - return json_rpc_error_response( - request.get("id"), - INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrailing_result["errors"], - response_type=ctx.failure_response_format, - ), True - - # Add the message to the trace - ctx.trace.append(message) - return request, False - - -async def hook_tool_result(ctx: McpContext, result: dict) -> dict: - """ - Hook function to intercept tool results. - Returns the potentially modified result. - """ - method = ctx.id_to_method_mapping.get(result.get("id")) - call_id = f"call_{result.get('id')}" - - # Safely handle result object - result_obj = result.get("result", {}) - if result_obj.get(MCP_SERVER_INFO): - ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "") - - if not method: - return result - elif method == MCP_TOOL_CALL: - message = { - "role": "tool", - "content": result_obj.get("content"), - "error": result_obj.get("error"), - "tool_call_id": call_id, - } - # Check for blocking guardrails - guardrailing_result = await get_guardrails_check_result( - ctx, message, action=GuardrailAction.BLOCK - ) - - if guardrailing_result and guardrailing_result.get("errors", []): - result = json_rpc_error_response( - result.get("id"), - INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % format_errors_in_response(guardrailing_result["errors"]), - response_type=ctx.failure_response_format, # Using content type as that's what the original code used - ) - - if ctx.push_explorer: - await append_and_push_trace(ctx, message, guardrailing_result) - else: - ctx.trace.append(message) - - return result - elif method == MCP_LIST_TOOLS: - ctx.tools = result_obj.get("tools") - message = { - "role": "tool", - "content": json.dumps(result.get("result").get("tools")), - "tool_call_id": call_id, - } - # next validate it with guardrails - guardrailing_result = await get_guardrails_check_result( - ctx, message, action=GuardrailAction.BLOCK - ) - if guardrailing_result and guardrailing_result.get("errors", []): - result["result"]["tools"] = [ - { - "name": "blocked_" + tool["name"], - "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE - % format_errors_in_response(guardrailing_result["errors"]), - # no parameters - "inputSchema": { - "properties": {}, - "required": [], - "title": "invariant_mcp_server_blockedArguments", - "type": "object", - }, - "annotations": { - "title": "This tool was blocked by security guardrails.", - }, - } - for tool in result["result"]["tools"] - ] - - # add it to the session trace (and run logging guardrails) - if ctx.push_explorer: - await append_and_push_trace(ctx, message, guardrailing_result) - else: - ctx.trace.append(message) - - return result - else: - return result - - async def stream_and_forward_stdout( - mcp_process: subprocess.Popen, ctx: McpContext + session_id: str, mcp_process: subprocess.Popen ) -> None: """Read from the mcp_process stdout, apply guardrails and forward to sys.stdout""" loop = asyncio.get_event_loop() - while True: + if mcp_process.poll() is not None: + mcp_log(f"[ERROR] MCP process terminated with code: {mcp_process.poll()}") + break + line = await loop.run_in_executor(None, mcp_process.stdout.readline) if not line: break try: # Process complete JSON lines - line_str = line.decode(UTF_8).strip() - if not line_str: + decoded_line = line.decode(UTF_8).strip() + if not decoded_line: continue + session = session_store.get_session(session_id) + if session.attributes.verbose: + mcp_log(f"[INFO] server -> client: {decoded_line}") + response_body = json.loads(decoded_line) + update_mcp_server_in_session_metadata(session, response_body) - if ctx.verbose: - mcp_log(f"[INFO] server -> client: {line_str}") - - parsed_json = json.loads(line_str) - processed_json = await hook_tool_result(ctx, parsed_json) - + intercept_response_result, _ = await intercept_response( + session_id, session_store, response_body + ) # Write and flush immediately - sys.stdout.buffer.write(write_as_utf8_bytes(processed_json)) + sys.stdout.buffer.write(write_as_utf8_bytes(intercept_response_result)) sys.stdout.buffer.flush() - except Exception as e: # pylint: disable=broad-except mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}") if line: @@ -407,74 +91,52 @@ async def stream_and_forward_stderr( MCP_LOG_FILE.buffer.flush() -async def process_line( - ctx: McpContext, mcp_process: subprocess.Popen, line: bytes +async def _intercept_request( + session_id: str, mcp_process: subprocess.Popen, line: bytes ) -> None: - """Process a line of input from stdin, decode it, and forward to mcp_process.""" - if ctx.verbose: + """ + Process a line of input from stdin, decode it and check for guardrails. + If the request is blocked, it returns a message indicating the block reason + otherwise it forwards the request to mcp_process stdin. + """ + session = session_store.get_session(session_id) + if session.attributes.verbose: mcp_log(f"[INFO] client -> server: {line}") # Try to decode and parse as JSON to check for tool calls try: text = line.decode(UTF_8) - parsed_json = json.loads(text) + request_body = json.loads(text) except json.JSONDecodeError as je: mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}") mcp_log(f"[ERROR] Problematic line: {line[:200]}...") return + update_session_from_request(session, request_body) + # Refresh guardrails + await session.load_guardrails() - if parsed_json.get(MCP_METHOD) is not None: - ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD) - if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO): - ctx.mcp_client_name = ( - parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") + hook_tool_call_result = {} + is_blocked = False + if request_body.get(MCP_METHOD) == MCP_TOOL_CALL: + hook_tool_call_result, is_blocked = await hook_tool_call( + session_id, session_store, request_body ) - - # Check if this is a tool call request - if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL: - # Refresh guardrails - run_task_sync(ctx.load_guardrails) - - # Intercept and potentially block modify the request - hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json) - if not is_blocked: - # If blocked, hook_tool_call_result contains the original request. - # Forward the request to the MCP process. - # It will handle the request and return a response. - mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result)) - mcp_process.stdin.flush() - else: - # If blocked, hook_tool_call_result contains the block message. - # Forward the block message result back to the caller. - # The original request is not passed to the MCP process. - sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result)) - sys.stdout.buffer.flush() + elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS: + hook_tool_call_result, is_blocked = await hook_tool_call( + session_id=session_id, + session_store=session_store, + request_body={ + "id": request_body.get("id"), + "method": MCP_LIST_TOOLS, + "params": {"name": MCP_LIST_TOOLS, "arguments": {}}, + }, + ) + if is_blocked: + sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result)) + sys.stdout.buffer.flush() return - else: - # pass through the request to the MCP process - - # for list_tools, extend the trace by a tool call - if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS: - # Refresh guardrails - run_task_sync(ctx.load_guardrails) - ctx.trace.append( - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": f"call_{parsed_json.get('id')}", - "type": "function", - "function": { - "name": "tools/list", - "arguments": {}, - }, - } - ], - } - ) - mcp_process.stdin.write(write_as_utf8_bytes(parsed_json)) - mcp_process.stdin.flush() + mcp_process.stdin.write(write_as_utf8_bytes(request_body)) + mcp_process.stdin.flush() async def wait_for_stdin_input( @@ -523,7 +185,7 @@ async def wait_for_stdin_input( async def run_stdio_input_loop( - ctx: McpContext, + session_id: str, mcp_process: subprocess.Popen, stdout_task: asyncio.Task, stderr_task: asyncio.Task, @@ -557,7 +219,7 @@ async def run_stdio_input_loop( if not line: continue - await process_line(ctx, mcp_process, line) + await _intercept_request(session_id, mcp_process, line) except (BrokenPipeError, KeyboardInterrupt): # Broken pipe = client disappeared, just start shutdown mcp_log("Client disconnected or keyboard interrupt") @@ -570,7 +232,7 @@ async def run_stdio_input_loop( while b"\n" in buffer: line, buffer = buffer.split(b"\n", 1) if line: - await process_line(ctx, mcp_process, line) + await _intercept_request(session_id, mcp_process, line) # Terminate process if needed if mcp_process.poll() is None: @@ -631,7 +293,11 @@ async def execute(args: list[str] = None): mcp_log("[INFO] Running with Python version:", sys.version) mcp_gateway_args, mcp_server_command_args = split_args(args) - ctx = McpContext(mcp_gateway_args) + session_id = generate_session_id() + await session_store.initialize_session( + session_id, + McpAttributes.from_cli_args(mcp_gateway_args), + ) mcp_process = subprocess.Popen( mcp_server_command_args, @@ -642,8 +308,10 @@ async def execute(args: list[str] = None): ) # Start async tasks for stdout and stderr - stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx)) + stdout_task = asyncio.create_task( + stream_and_forward_stdout(session_id, mcp_process) + ) stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process)) # Handle forwarding stdin and intercept tool calls - await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task) + await run_stdio_input_loop(session_id, mcp_process, stdout_task, stderr_task) diff --git a/gateway/mcp/mcp_context.py b/gateway/mcp/mcp_context.py deleted file mode 100644 index 4da2923..0000000 --- a/gateway/mcp/mcp_context.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Context manager for MCP (Model Context Protocol) gateway.""" - -import argparse -import os -import random -import uuid -from typing import Dict - -from gateway.integrations.explorer import ( - fetch_guardrails_from_explorer, -) -from gateway.common.guardrails import GuardrailRuleSet - - -class McpContext: - """Singleton class to manage MCP context and state.""" - - _instance = None - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super(McpContext, cls).__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self, cli_args: list): - if not hasattr(self, "_initialized"): - self._initialized = False - if self._initialized: - return - - config, extra_args = self._parse_cli_args(cli_args) - - # The project name is used to identify the dataset in Invariant Explorer. - self.explorer_dataset = config.project_name - # whether to push traces to Invariant Explorer - self.push_explorer = config.push_explorer - # the format to use to communicate guardrail failures to the client - self.failure_response_format = config.failure_response_format - # verbose logging of in/out - self.verbose = config.verbose - - # trace of this MCP session - self.trace = [] - # tools available to the MCP server - self.tools = [] - - # configured guardrails - self.guardrails = GuardrailRuleSet( - blocking_guardrails=[], logging_guardrails=[] - ) - - # parsed from CLI (all --metadata-* args) - self.extra_metadata: Dict[str, str] = {} - for arg in extra_args: - assert "=" in arg, f"Invalid extra metadata argument: {arg}" - key, value = arg.split("=") - assert key.startswith( - "--metadata-" - ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" - key = key[len("--metadata-") :] - self.extra_metadata[key] = value - - # captured from MCP calls/responses - self.mcp_client_name = "" - self.mcp_server_name = "" - - # We send the same trace messages for guardrails analysis multiple times. - # We need to deduplicate them before sending to the explorer. - self.annotations = [] - self.trace_id = None - self.local_session_id = str(uuid.uuid4()) - self.last_trace_length = 0 - self.id_to_method_mapping = {} - self._initialized = True - - def _parse_cli_args(self, cli_args: list) -> argparse.Namespace: - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="MCP Gateway") - parser.add_argument( - "--project-name", - help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.", - type=str, - default=f"mcp-capture-{random.randint(1, 100)}", - ) - parser.add_argument( - "--push-explorer", - help="Enable pushing traces to Invariant Explorer", - action="store_true", - ) - parser.add_argument( - "--verbose", - help="Enable verbose logging", - action="store_true", - ) - parser.add_argument( - "--failure-response-format", - help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", - type=str, - default="error", - ) - - return parser.parse_known_args(cli_args) - - async def load_guardrails(self): - """Run async setup logic (e.g. fetching guardrails).""" - self.guardrails = await fetch_guardrails_from_explorer( - self.explorer_dataset, - "Bearer " + os.getenv("INVARIANT_API_KEY"), - self.extra_metadata.get("client", self.mcp_client_name), - self.extra_metadata.get("server", self.mcp_server_name), - ) diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index 4f71a6f..3c27ebf 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -15,10 +15,6 @@ from gateway.common.constants import ( MCP_METHOD, MCP_TOOL_CALL, MCP_LIST_TOOLS, - MCP_PARAMS, - MCP_RESULT, - MCP_SERVER_INFO, - MCP_CLIENT_INFO, UTF_8, ) from gateway.common.mcp_sessions_manager import ( @@ -28,7 +24,9 @@ from gateway.common.mcp_sessions_manager import ( from gateway.common.mcp_utils import ( get_mcp_server_base_url, hook_tool_call, - hook_tool_call_response, + intercept_response, + update_mcp_server_in_session_metadata, + update_session_from_request, ) MCP_SERVER_POST_HEADERS = { @@ -72,16 +70,7 @@ async def mcp_post_sse_gateway( request_body_bytes = await request.body() request_body = json.loads(request_body_bytes) session = session_store.get_session(session_id) - if request_body.get(MCP_METHOD) and request_body.get("id"): - session.id_to_method_mapping[request_body.get("id")] = request_body.get( - MCP_METHOD - ) - if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( - MCP_CLIENT_INFO - ): - session.attributes.metadata["mcp_client"] = ( - request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") - ) + update_session_from_request(session, request_body) if request_body.get(MCP_METHOD) == MCP_TOOL_CALL: # Intercept and potentially block the request @@ -137,8 +126,6 @@ async def mcp_post_sse_gateway( print(f"[MCP POST] Request error: {str(e)}") raise HTTPException(status_code=500, detail="Request error") from e except Exception as e: - import traceback - traceback.print_exc() print(f"[MCP POST] Unexpected error: {str(e)}") raise HTTPException(status_code=500, detail="Unexpected error") from e @@ -340,59 +327,18 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8) session = session_store.get_session(session_id) try: - response_json = json.loads(sse.data) + response_body = json.loads(sse.data) + update_mcp_server_in_session_metadata(session, response_body) - if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get( - MCP_SERVER_INFO - ): - session.attributes.metadata["mcp_server"] = ( - response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") + intercept_response_result, is_blocked = await intercept_response( + session_id=session_id, + session_store=session_store, + response_body=response_body, + ) + if is_blocked: + event_bytes = f"event: {sse.event}\ndata: {json.dumps(intercept_response_result)}\n\n".encode( + UTF_8 ) - - method = session.id_to_method_mapping.get(response_json.get("id")) - if method == MCP_TOOL_CALL: - result, blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_json=response_json, - ) - # Update the event bytes with hook_tool_call_response. - # hook_tool_call_response is same as response_json if no guardrail is violated. - # If guardrail is violated, it contains the error message. - # pylint: disable=line-too-long - if blocked: - event_bytes = ( - f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8) - ) - elif method == MCP_LIST_TOOLS: - # store tools in metadata - session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get( - MCP_RESULT - ).get("tools") - # store tools/list tool call in trace - result, blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_json={ - "id": response_json.get("id"), - "result": { - "content": json.dumps( - response_json.get(MCP_RESULT).get("tools") - ), - "tools": response_json.get(MCP_RESULT).get("tools"), - }, - }, - is_tools_list=True, - ) - # Update the event bytes with hook_tool_call_response. - # hook_tool_call_response is same as response_json if no guardrail is violated. - # If guardrail is violated, it contains the error message. - # pylint: disable=line-too-long - if blocked: - event_bytes = ( - f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8) - ) - except json.JSONDecodeError as e: print( f"[MCP SSE] Error parsing message JSON: {e}", diff --git a/gateway/routes/mcp_streamable.py b/gateway/routes/mcp_streamable.py index bc5c2e8..4d806d9 100644 --- a/gateway/routes/mcp_streamable.py +++ b/gateway/routes/mcp_streamable.py @@ -1,9 +1,6 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" import json -import uuid - -from typing import Tuple import httpx @@ -13,12 +10,8 @@ from fastapi.responses import StreamingResponse from gateway.common.constants import ( CLIENT_TIMEOUT, INVARIANT_SESSION_ID_PREFIX, - MCP_CLIENT_INFO, MCP_LIST_TOOLS, MCP_METHOD, - MCP_PARAMS, - MCP_RESULT, - MCP_SERVER_INFO, MCP_TOOL_CALL, UTF_8, ) @@ -27,9 +20,13 @@ from gateway.common.mcp_sessions_manager import ( McpAttributes, ) from gateway.common.mcp_utils import ( + generate_session_id, get_mcp_server_base_url, hook_tool_call, - hook_tool_call_response, + intercept_response, + update_mcp_client_info_in_session, + update_mcp_server_in_session_metadata, + update_tool_call_id_in_session, ) gateway = APIRouter() @@ -69,7 +66,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse: # If a session ID is provided in the request headers, it was already initialized # in McpSessionsManager. This might be a session ID returned by the MCP server # or a session ID generated in the gateway. - _update_tool_call_id_in_session(session_id, request_body) + update_tool_call_id_in_session( + session_store.get_session(session_id), request_body + ) elif is_initialization_request: # If this is an initialization request, we generate a session ID, # We don't call initialize_session here because we don't know @@ -77,7 +76,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse: # If later in the response from MCP server, we don't receive a session ID then this # will be initialized and returned back to the client else this will be # overwritten by the session ID returned by the MCP server. - session_id = _generate_session_id() + session_id = generate_session_id() # Intercept the request and check for guardrails. if not is_initialization_request: @@ -115,9 +114,9 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse: # Update client info if this is an initialization request if is_initialization_request: - _update_mcp_client_info_in_session( - session_id=session_id, - request_body=request_body, + update_mcp_client_info_in_session( + session_store.get_session(session_id), + request_body, ) # If the response is JSON type, handle it as a JSON response. @@ -281,52 +280,17 @@ def _get_mcp_server_endpoint(request: Request) -> str: return get_mcp_server_base_url(request) + "/mcp/" -def _generate_session_id() -> str: - """ - Generate a new session ID. - If the MCP server is session less then we don't have a session ID from the MCP server. - """ - return INVARIANT_SESSION_ID_PREFIX + uuid.uuid4().hex - - -def _update_tool_call_id_in_session(session_id: str, request_body: dict) -> None: - """ - Updates the tool call ID in the session. - """ - session = session_store.get_session(session_id) - if request_body.get(MCP_METHOD) and request_body.get("id"): - session.id_to_method_mapping[request_body.get("id")] = request_body.get( - MCP_METHOD - ) - - -def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> None: - """ - Update the MCP client info in the session metadata. - """ - session = session_store.get_session(session_id) - if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( - MCP_CLIENT_INFO - ): - session.attributes.metadata["mcp_client"] = ( - request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") - ) - - def _update_mcp_response_info_in_session( - session_id: str, response_json: dict, is_json_response: bool + session_id: str, response_body: dict, is_json_response: bool ) -> None: """ Update the MCP response info in the session metadata. """ session = session_store.get_session(session_id) - if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get( - MCP_SERVER_INFO - ): - session.attributes.metadata["mcp_server"] = ( - response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") - ) - session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse" + update_mcp_server_in_session_metadata(session, response_body) + session.attributes.metadata["server_response_type"] = ( + "json" if is_json_response else "sse" + ) def _is_initialization_request(request_body: dict) -> bool: @@ -353,18 +317,20 @@ async def _handle_mcp_json_response( # return the error message else return the response as is response_content = response.content # The server response is empty string when client sends "notifications/initialized" - response_json = ( + response_body = ( json.loads(response_content.decode(UTF_8)) if response_content else {} ) - if response_json: + if response_body: _update_mcp_response_info_in_session( - session_id=session_id, response_json=response_json, is_json_response=True + session_id=session_id, response_body=response_body, is_json_response=True ) response_code = response.status_code if not is_initialization_request: - intercept_response_result, blocked = await _intercept_response( - session_id=session_id, response_json=response_json + intercept_response_result, blocked = await intercept_response( + session_id=session_id, + session_store=session_store, + response_body=response_body, ) if blocked: response_content = json.dumps(intercept_response_result).encode(UTF_8) @@ -405,14 +371,15 @@ async def _handle_mcp_streaming_response( if not stripped_line: break # End of stream if buffer: - response_json = json.loads(stripped_line.split("data: ")[1].strip()) + response_body = json.loads(stripped_line.split("data: ")[1].strip()) if not is_initialization_request: ( intercept_response_result, blocked, - ) = await _intercept_response( + ) = await intercept_response( session_id=session_id, - response_json=response_json, + session_store=session_store, + response_body=response_body, ) if blocked: yield ( @@ -423,7 +390,7 @@ async def _handle_mcp_streaming_response( else: _update_mcp_response_info_in_session( session_id=session_id, - response_json=response_json, + response_body=response_body, is_json_response=False, ) yield f"{buffer}\n{stripped_line}\n\n" @@ -482,46 +449,3 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response | media_type="application/json", ) return None - - -async def _intercept_response( - session_id: str, response_json: dict -) -> Tuple[dict, bool]: - """ - Intercept the response and check for guardrails. - This function is used to intercept responses and check for guardrails. - If the response is blocked, it returns a message indicating the block - reason with a boolean flag set to True. If the response is not blocked, - it returns the original response with a boolean flag set to False. - """ - session = session_store.get_session(session_id) - method = session.id_to_method_mapping.get(response_json.get("id")) - # Intercept and potentially block tool call response - if method == MCP_TOOL_CALL: - result, blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_json=response_json, - ) - return result, blocked - # Intercept and potentially block list tool call response - elif method == MCP_LIST_TOOLS: - # store tools in metadata - session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get( - MCP_RESULT - ).get("tools") - # store tools/list tool call in trace - result, blocked = await hook_tool_call_response( - session_id=session_id, - session_store=session_store, - response_json={ - "id": response_json.get("id"), - "result": { - "content": json.dumps(response_json.get(MCP_RESULT).get("tools")), - "tools": response_json.get(MCP_RESULT).get("tools"), - }, - }, - is_tools_list=True, - ) - return result, blocked - return response_json, False diff --git a/tests/integration/resources/mcp/stdio/client/main.py b/tests/integration/resources/mcp/stdio/client/main.py index 559a87a..ebc55b5 100644 --- a/tests/integration/resources/mcp/stdio/client/main.py +++ b/tests/integration/resources/mcp/stdio/client/main.py @@ -1,5 +1,6 @@ """A MCP client implementation that interacts with MCP server to make tool calls.""" +import asyncio import os from datetime import timedelta @@ -48,7 +49,6 @@ class MCPClient: for key, value in metadata_keys.items(): args.append("--metadata-" + key + "=" + value) - if push_to_explorer: args.append("--push-explorer") args.extend( @@ -133,7 +133,7 @@ async def run( project_name, server_script_path, push_to_explorer, - metadata_keys=metadata_keys + metadata_keys=metadata_keys, ) listed_tools = await client.session.list_tools() if tool_name == "tools/list": @@ -142,4 +142,6 @@ async def run( else: return await client.call_tool(tool_name, tool_args) finally: + if push_to_explorer: + await asyncio.sleep(2) await client.cleanup()