diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py index 29f225f..2074974 100644 --- a/gateway/mcp/mcp.py +++ b/gateway/mcp/mcp.py @@ -5,7 +5,7 @@ import subprocess import json import os import select -import threading +import asyncio from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.append_messages import AppendMessagesRequest @@ -14,6 +14,9 @@ from invariant_sdk.types.push_traces import PushTracesRequest from gateway.common.constants import ( INVARIANT_GUARDRAILS_BLOCKED_MESSAGE, MCP_METHOD, + MCP_CLIENT_INFO, + MCP_PARAMS, + MCP_SERVER_INFO, MCP_TOOL_CALL, MCP_LIST_TOOLS, ) @@ -23,7 +26,6 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err from gateway.integrations.guardrails import check_guardrails from gateway.mcp.log import mcp_log, MCP_LOG_FILE from gateway.mcp.mcp_context import McpContext -from gateway.mcp.task_utils import run_task_in_background, run_task_sync UTF_8_ENCODING = "utf-8" DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -69,16 +71,50 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool: return False +async def get_guardrails_check_result( + ctx: McpContext, + message: dict, + action: GuardrailAction = GuardrailAction.BLOCK, +) -> dict: + """ + Check against guardrails of type action in an async manner. 
+ """ + # Skip if no guardrails are configured for this action + if not ( + (ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK) + or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG) + ): + return {} + + # Prepare context and select appropriate guardrails + context = RequestContext.create( + request_json={}, + dataset_name=ctx.explorer_dataset, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + guardrails=ctx.guardrails, + ) + + guardrails_to_check = ( + ctx.guardrails.blocking_guardrails + if action == GuardrailAction.BLOCK + else ctx.guardrails.logging_guardrails + ) + + # Run check_guardrails asynchronously + return await check_guardrails( + messages=ctx.trace + [message], + guardrails=guardrails_to_check, + context=context, + ) + + async def append_and_push_trace( ctx: McpContext, message: dict, guardrails_result: dict ) -> None: """ Append a message to the trace if it exists or create a new one and push it to the Invariant Explorer. - - This function runs asynchronously in the background. """ - annotations = [] if guardrails_result and guardrails_result.get("errors", []): annotations = create_annotations_from_guardrails_errors( @@ -86,7 +122,7 @@ async def append_and_push_trace( ) if ctx.guardrails.logging_guardrails: - logging_guardrails_check_result = get_guardrails_check_result( + logging_guardrails_check_result = await get_guardrails_check_result( ctx, message, action=GuardrailAction.LOG ) if logging_guardrails_check_result and logging_guardrails_check_result.get( @@ -138,45 +174,7 @@ async def append_and_push_trace( mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e) -def get_guardrails_check_result( - ctx: McpContext, - message: dict, - action: GuardrailAction = GuardrailAction.BLOCK, -) -> dict: - """ - Check against guardrails of type action. - Works in both sync and async contexts by always using a dedicated thread. 
- """ - # Skip if no guardrails are configured for this action - if not ( - (ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK) - or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG) - ): - return {} - - # Prepare context and select appropriate guardrails - context = RequestContext.create( - request_json={}, - dataset_name=ctx.explorer_dataset, - invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), - guardrails=ctx.guardrails, - ) - - guardrails_to_check = ( - ctx.guardrails.blocking_guardrails - if action == GuardrailAction.BLOCK - else ctx.guardrails.logging_guardrails - ) - - return run_task_sync( - check_guardrails, - messages=ctx.trace + [message], - guardrails=guardrails_to_check, - context=context, - ) - - -def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: +async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: """ Hook function to intercept tool calls. @@ -195,23 +193,20 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} - # Check for blocking guardrails - this blocks until completion - guardrailing_result = get_guardrails_check_result( + # Check for blocking guardrails + guardrailing_result = await get_guardrails_check_result( ctx, message, action=GuardrailAction.BLOCK ) # If the request is blocked, return a message indicating the block reason. - # If there are new errors, run append_and_push_trace in background. - # If there are no new errors, just return the original request. 
if ( guardrailing_result and guardrailing_result.get("errors", []) and check_if_new_errors(ctx, guardrailing_result) ): if ctx.push_explorer: - run_task_in_background( - append_and_push_trace, ctx, message, guardrailing_result - ) + await append_and_push_trace(ctx, message, guardrailing_result) + return { "jsonrpc": "2.0", "id": request.get("id"), @@ -227,28 +222,30 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: return request, False -def hook_tool_result(ctx: McpContext, result: dict) -> dict: +async def hook_tool_result(ctx: McpContext, result: dict) -> dict: """ Hook function to intercept tool results. - Modify this function to change behavior for tool results. Returns the potentially modified result. """ method = ctx.id_to_method_mapping.get(result.get("id")) call_id = f"call_{result.get('id')}" - if "serverInfo" in result.get("result"): - ctx.mcp_server_name = result.get("result").get("serverInfo").get("name", "") + + # Safely handle result object + result_obj = result.get("result", {}) + if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj: + ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "") if method is None: return result elif method == MCP_TOOL_CALL: message = { "role": "tool", - "content": result.get("result").get("content"), - "error": result.get("result").get("error"), + "content": result_obj.get("content"), + "error": result_obj.get("error"), "tool_call_id": call_id, } - # Check for blocking guardrails - this blocks until completion - guardrailing_result = get_guardrails_check_result( + # Check for blocking guardrails + guardrailing_result = await get_guardrails_check_result( ctx, message, action=GuardrailAction.BLOCK ) @@ -264,21 +261,27 @@ def hook_tool_result(ctx: McpContext, result: dict) -> dict: } if ctx.push_explorer: - # Run append_and_push_trace in background - run_task_in_background( - append_and_push_trace, ctx, message, guardrailing_result - ) + await 
append_and_push_trace(ctx, message, guardrailing_result) + return result elif method == MCP_LIST_TOOLS: - ctx.tools = result.get("result").get("tools") + ctx.tools = result_obj.get("tools") return result else: return result -def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None: - """Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout""" - for line in iter(mcp_process.stdout.readline, b""): +async def stream_and_forward_stdout( + mcp_process: subprocess.Popen, ctx: McpContext +) -> None: + """Read from the mcp_process stdout, apply guardrails and forward to sys.stdout""" + loop = asyncio.get_event_loop() + + while True: + line = await loop.run_in_executor(None, mcp_process.stdout.readline) + if not line: + break + try: # Process complete JSON lines line_str = line.decode(UTF_8_ENCODING).strip() @@ -286,14 +289,11 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> continue parsed_json = json.loads(line_str) - processed_json = hook_tool_result(ctx, parsed_json) + processed_json = await hook_tool_result(ctx, parsed_json) # Write and flush immediately sys.stdout.buffer.write(write_as_utf8_bytes(processed_json)) sys.stdout.buffer.flush() - except json.JSONDecodeError as je: - mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}") - mcp_log(f"[ERROR] Problematic line: {line[:200]}...") except Exception as e: mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}") @@ -301,79 +301,133 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...") -def stream_and_forward_stderr( - mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1 +async def stream_and_forward_stderr( + mcp_process: subprocess.Popen, read_chunk_size: int = 10 ) -> None: """Read from the mcp_process stderr and write to sys.stderr""" - for line in iter(lambda: 
mcp_process.stderr.read(read_chunk_size), b""): - MCP_LOG_FILE.buffer.write(line) + loop = asyncio.get_running_loop() + + while True: + chunk = await loop.run_in_executor( + None, lambda: mcp_process.stderr.read(read_chunk_size) + ) + if not chunk: # EOF: stderr was closed, stop instead of spinning on empty reads + break + MCP_LOG_FILE.buffer.write(chunk) MCP_LOG_FILE.buffer.flush() -def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None: +async def process_line( + ctx: McpContext, mcp_process: subprocess.Popen, line: bytes +) -> None: + """Process a line of input from stdin, check for tool calls and forward to mcp_process.""" + try: + text = line.decode(UTF_8_ENCODING) + parsed_json = json.loads(text) + + if parsed_json.get(MCP_METHOD) is not None: + ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get( + MCP_METHOD + ) + if MCP_PARAMS in parsed_json and MCP_CLIENT_INFO in parsed_json.get(MCP_PARAMS): + ctx.mcp_client_name = ( + parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") + ) + + # Check if this is a tool call request + if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL: + # Refresh guardrails + await ctx.load_guardrails() + + # Intercept and potentially block/modify the request + hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json) + if not is_blocked: + # Forward the request to the MCP process + mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result)) + mcp_process.stdin.flush() + else: + # Forward the block message result back to the caller + sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result)) + sys.stdout.buffer.flush() + else: + mcp_process.stdin.write(write_as_utf8_bytes(parsed_json)) + mcp_process.stdin.flush() + except Exception as e: # pylint: disable=broad-except + # Invalid JSON or a failed check/forward: the line is dropped, so record why + mcp_log(f"[ERROR] Error in process_line: {str(e)}") + + +async def run_stdio_input_loop( + ctx: McpContext, + mcp_process: subprocess.Popen, + stdout_task: asyncio.Task, + stderr_task: asyncio.Task, +) -> None: """Handle standard input, intercept 
call and forward requests to mcp_process stdin.""" + loop = asyncio.get_running_loop() + stdin_fd = sys.stdin.fileno() + buffer = b"" + + # Set stdin to non-blocking mode + os.set_blocking(stdin_fd, False) try: while True: - ready, _, _ = select.select([sys.stdin], [], [], 0.1) + # Check for input using select + ready, _, _ = await loop.run_in_executor( + None, lambda: select.select([stdin_fd], [], [], 0.1) + ) + if not ready: + # No input available, yield to other tasks + await asyncio.sleep(0.01) continue - line = sys.stdin.buffer.readline() - if not line: - break + # Read available data + chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) + if not chunk: + break # EOF - # Try to decode and parse as JSON to check for tool calls + buffer += chunk + + # Process complete lines + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if not line: + continue + + await process_line(ctx, mcp_process, line) + + except (BrokenPipeError, KeyboardInterrupt): + # Broken pipe = client disappeared, just start shutdown + mcp_log("Client disconnected or keyboard interrupt") + finally: + # Process any remaining data while the MCP process stdin is still open + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if line: + await process_line(ctx, mcp_process, line) + + # Close stdin only after the remaining lines were forwarded + if mcp_process.stdin: + mcp_process.stdin.close() + + # Terminate process if needed + if mcp_process.poll() is None: + mcp_process.terminate() try: - text = line.decode(UTF_8_ENCODING) - parsed_json = json.loads(text) - if parsed_json.get(MCP_METHOD) is not None: - ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get( - MCP_METHOD - ) - if "params" in parsed_json and "clientInfo" in parsed_json.get( - "params" - ): - ctx.mcp_client_name = ( - parsed_json.get("params").get("clientInfo").get("name", "") - ) + await asyncio.wait_for( + loop.run_in_executor(None, mcp_process.wait), timeout=2 + ) + except asyncio.TimeoutError: + mcp_process.kill() - # Check if this is a tool call request - 
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL: - # Refresh guardrails - run_task_sync(ctx.load_guardrails) + # Cancel I/O tasks + stdout_task.cancel() + stderr_task.cancel() - # Intercept and potentially block modify the request - hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json) - if not is_blocked: - # If blocked, hook_tool_call_result contains the original request. - # Forward the request to the MCP process. - # It will handle the request and return a response. - mcp_process.stdin.write( - write_as_utf8_bytes(hook_tool_call_result) - ) - mcp_process.stdin.flush() - else: - # If blocked, hook_tool_call_result contains the block message. - # Forward the block message result back to the caller. - # The original request is not passed to the MCP process. - sys.stdout.buffer.write( - write_as_utf8_bytes(hook_tool_call_result) - ) - sys.stdout.buffer.flush() - continue - else: - mcp_process.stdin.write(write_as_utf8_bytes(parsed_json)) - mcp_process.stdin.flush() - continue - except Exception: - # Not a complete or valid JSON, just pass through - pass - - except BrokenPipeError: - pass - except KeyboardInterrupt: - mcp_process.terminate() + # Final flush + sys.stdout.flush() def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: @@ -425,17 +479,9 @@ async def execute(args: list[str] = None): bufsize=0, ) - # Start threads to forward stdout and stderr - threading.Thread( - target=stream_and_forward_stdout, - args=(mcp_process, ctx), - daemon=True, - ).start() - threading.Thread( - target=stream_and_forward_stderr, - args=(mcp_process, ctx), - daemon=True, - ).start() + # Start async tasks for stdout and stderr + stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx)) + stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process)) # Handle forwarding stdin and intercept tool calls - run_stdio_input_loop(ctx, mcp_process) + await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task) diff 
--git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index d326992..eee28bf 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -17,7 +17,7 @@ MCP_SSE_SERVER_PORT = 8123 @pytest.mark.asyncio -@pytest.mark.timeout(15) +@pytest.mark.timeout(30) @pytest.mark.parametrize( "push_to_explorer, transport", [ @@ -97,7 +97,7 @@ async def test_mcp_with_gateway( @pytest.mark.asyncio -@pytest.mark.timeout(15) +@pytest.mark.timeout(30) @pytest.mark.parametrize("transport", ["stdio", "sse"]) async def test_mcp_with_gateway_and_logging_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport @@ -205,11 +205,13 @@ async def test_mcp_with_gateway_and_logging_guardrails( tool_call_annotation is not None ), "Missing 'get_last_message_from_user is called' annotation" assert food_annotation["extra_metadata"]["source"] == "guardrails-error" + assert food_annotation["extra_metadata"]["guardrail"]["action"] == "log" assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error" + assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log" @pytest.mark.asyncio -@pytest.mark.timeout(15) +@pytest.mark.timeout(30) @pytest.mark.parametrize("transport", ["stdio", "sse"]) async def test_mcp_with_gateway_and_blocking_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport @@ -298,10 +300,11 @@ async def test_mcp_with_gateway_and_blocking_guardrails( and annotations[0]["address"] == "messages.0.tool_calls.0" ) assert annotations[0]["extra_metadata"]["source"] == "guardrails-error" + assert annotations[0]["extra_metadata"]["guardrail"]["action"] == "block" @pytest.mark.asyncio -@pytest.mark.timeout(15) +@pytest.mark.timeout(30) @pytest.mark.parametrize("transport", ["stdio", "sse"]) async def test_mcp_sse_with_gateway_hybrid_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport @@ -416,4 +419,6 @@ async def 
test_mcp_sse_with_gateway_hybrid_guardrails( tool_call_annotation is not None ), "Missing 'get_last_message_from_user is called' annotation" assert food_annotation["extra_metadata"]["source"] == "guardrails-error" + assert food_annotation["extra_metadata"]["guardrail"]["action"] == "block" assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error" + assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log" diff --git a/tests/integration/resources/mcp/stdio/client/main.py b/tests/integration/resources/mcp/stdio/client/main.py index 020db99..0b12814 100644 --- a/tests/integration/resources/mcp/stdio/client/main.py +++ b/tests/integration/resources/mcp/stdio/client/main.py @@ -1,7 +1,7 @@ """A MCP client implementation that interacts with MCP server to make tool calls.""" -import asyncio import os + from datetime import timedelta from contextlib import AsyncExitStack from typing import Any, Optional @@ -69,7 +69,7 @@ class MCPClient: self.stdio, self.write = stdio_transport self.session = await self.exit_stack.enter_async_context( ClientSession( - self.stdio, self.write, read_timeout_seconds=timedelta(seconds=10) + self.stdio, self.write, read_timeout_seconds=timedelta(seconds=15) ) ) @@ -85,10 +85,6 @@ class MCPClient: tool_name: Name of the tool to call tool_args: Arguments for the tool call """ - response = await self.session.list_tools() - if tool_name not in [tool.name for tool in response.tools]: - raise ValueError(f"Tool '{tool_name}' not found in available tools") - # Execute tool call result = await self.session.call_tool(tool_name, tool_args) return result @@ -130,8 +126,4 @@ async def run( ) return await client.call_tool(tool_name, tool_args) finally: - # Sleep for a while to allow the server to process the background tasks - # like pushing traces to the explorer - if push_to_explorer: - await asyncio.sleep(2) await client.cleanup()