Make MCP stdio gateway fully async. With sync and async mixed behaviour for running background tasks we were running into issues.

This commit is contained in:
Hemang
2025-05-15 16:01:57 +02:00
committed by Hemang Sarkar
parent a6c1124076
commit 876eb44c78
3 changed files with 198 additions and 155 deletions
+187 -141
View File
@@ -5,7 +5,7 @@ import subprocess
import json
import os
import select
import threading
import asyncio
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
@@ -14,6 +14,9 @@ from invariant_sdk.types.push_traces import PushTracesRequest
from gateway.common.constants import (
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
MCP_METHOD,
MCP_CLIENT_INFO,
MCP_PARAMS,
MCP_SERVER_INFO,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
)
@@ -23,7 +26,6 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err
from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
UTF_8_ENCODING = "utf-8"
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -69,16 +71,50 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
return False
async def get_guardrails_check_result(
ctx: McpContext,
message: dict,
action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
"""
Check against guardrails of type action in an async manner.
"""
# Skip if no guardrails are configured for this action
if not (
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
):
return {}
# Prepare context and select appropriate guardrails
context = RequestContext.create(
request_json={},
dataset_name=ctx.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=ctx.guardrails,
)
guardrails_to_check = (
ctx.guardrails.blocking_guardrails
if action == GuardrailAction.BLOCK
else ctx.guardrails.logging_guardrails
)
# Run check_guardrails asynchronously
return await check_guardrails(
messages=ctx.trace + [message],
guardrails=guardrails_to_check,
context=context,
)
async def append_and_push_trace(
ctx: McpContext, message: dict, guardrails_result: dict
) -> None:
"""
Append a message to the trace if it exists or create a new one
and push it to the Invariant Explorer.
This function runs asynchronously in the background.
"""
annotations = []
if guardrails_result and guardrails_result.get("errors", []):
annotations = create_annotations_from_guardrails_errors(
@@ -86,7 +122,7 @@ async def append_and_push_trace(
)
if ctx.guardrails.logging_guardrails:
logging_guardrails_check_result = get_guardrails_check_result(
logging_guardrails_check_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.LOG
)
if logging_guardrails_check_result and logging_guardrails_check_result.get(
@@ -138,45 +174,7 @@ async def append_and_push_trace(
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
def get_guardrails_check_result(
ctx: McpContext,
message: dict,
action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
"""
Check against guardrails of type action.
Works in both sync and async contexts by always using a dedicated thread.
"""
# Skip if no guardrails are configured for this action
if not (
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
):
return {}
# Prepare context and select appropriate guardrails
context = RequestContext.create(
request_json={},
dataset_name=ctx.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=ctx.guardrails,
)
guardrails_to_check = (
ctx.guardrails.blocking_guardrails
if action == GuardrailAction.BLOCK
else ctx.guardrails.logging_guardrails
)
return run_task_sync(
check_guardrails,
messages=ctx.trace + [message],
guardrails=guardrails_to_check,
context=context,
)
def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
"""
Hook function to intercept tool calls.
@@ -195,23 +193,20 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
# Check for blocking guardrails - this blocks until completion
guardrailing_result = get_guardrails_check_result(
# Check for blocking guardrails
guardrailing_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
# If the request is blocked, return a message indicating the block reason.
# If there are new errors, run append_and_push_trace in background.
# If there are no new errors, just return the original request.
if (
guardrailing_result
and guardrailing_result.get("errors", [])
and check_if_new_errors(ctx, guardrailing_result)
):
if ctx.push_explorer:
run_task_in_background(
append_and_push_trace, ctx, message, guardrailing_result
)
await append_and_push_trace(ctx, message, guardrailing_result)
return {
"jsonrpc": "2.0",
"id": request.get("id"),
@@ -227,28 +222,30 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
return request, False
def hook_tool_result(ctx: McpContext, result: dict) -> dict:
async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
"""
Hook function to intercept tool results.
Modify this function to change behavior for tool results.
Returns the potentially modified result.
"""
method = ctx.id_to_method_mapping.get(result.get("id"))
call_id = f"call_{result.get('id')}"
if "serverInfo" in result.get("result"):
ctx.mcp_server_name = result.get("result").get("serverInfo").get("name", "")
# Safely handle result object
result_obj = result.get("result", {})
if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj:
ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "")
if method is None:
return result
elif method == MCP_TOOL_CALL:
message = {
"role": "tool",
"content": result.get("result").get("content"),
"error": result.get("result").get("error"),
"content": result_obj.get("content"),
"error": result_obj.get("error"),
"tool_call_id": call_id,
}
# Check for blocking guardrails - this blocks until completion
guardrailing_result = get_guardrails_check_result(
# Check for blocking guardrails
guardrailing_result = await get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
@@ -264,21 +261,27 @@ def hook_tool_result(ctx: McpContext, result: dict) -> dict:
}
if ctx.push_explorer:
# Run append_and_push_trace in background
run_task_in_background(
append_and_push_trace, ctx, message, guardrailing_result
)
await append_and_push_trace(ctx, message, guardrailing_result)
return result
elif method == MCP_LIST_TOOLS:
ctx.tools = result.get("result").get("tools")
ctx.tools = result_obj.get("tools")
return result
else:
return result
def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None:
"""Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout"""
for line in iter(mcp_process.stdout.readline, b""):
async def stream_and_forward_stdout(
mcp_process: subprocess.Popen, ctx: McpContext
) -> None:
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
loop = asyncio.get_event_loop()
while True:
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
if not line:
break
try:
# Process complete JSON lines
line_str = line.decode(UTF_8_ENCODING).strip()
@@ -286,14 +289,11 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) ->
continue
parsed_json = json.loads(line_str)
processed_json = hook_tool_result(ctx, parsed_json)
processed_json = await hook_tool_result(ctx, parsed_json)
# Write and flush immediately
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
sys.stdout.buffer.flush()
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}")
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
except Exception as e:
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
@@ -301,79 +301,133 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) ->
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
def stream_and_forward_stderr(
mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1
async def stream_and_forward_stderr(
mcp_process: subprocess.Popen, read_chunk_size: int = 10
) -> None:
"""Read from the mcp_process stderr and write to sys.stderr"""
for line in iter(lambda: mcp_process.stderr.read(read_chunk_size), b""):
MCP_LOG_FILE.buffer.write(line)
loop = asyncio.get_event_loop()
while True:
# Read chunks asynchronously
chunk = await loop.run_in_executor(
None, lambda: mcp_process.stderr.read(read_chunk_size)
)
MCP_LOG_FILE.buffer.write(chunk)
MCP_LOG_FILE.buffer.flush()
def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None:
async def process_line(
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
) -> None:
"""Process a line of input from stdin, check for tool calls and forward to mcp_process."""
try:
text = line.decode(UTF_8_ENCODING)
parsed_json = json.loads(text)
if parsed_json.get(MCP_METHOD) is not None:
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
MCP_METHOD
)
if MCP_PARAMS in parsed_json and MCP_CLIENT_INFO in parsed_json.get(MCP_PARAMS):
ctx.mcp_client_name = (
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
# Check if this is a tool call request
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
# Refresh guardrails
await ctx.load_guardrails()
# Intercept and potentially block/modify the request
hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json)
if not is_blocked:
# Forward the request to the MCP process
mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result))
mcp_process.stdin.flush()
else:
# Forward the block message result back to the caller
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
sys.stdout.buffer.flush()
else:
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
mcp_process.stdin.flush()
except Exception: # pylint: disable=bare-except
# Not a complete or valid JSON, just pass through
pass
async def run_stdio_input_loop(
ctx: McpContext,
mcp_process: subprocess.Popen,
stdout_task: asyncio.Task,
stderr_task: asyncio.Task,
) -> None:
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
loop = asyncio.get_event_loop()
stdin_fd = sys.stdin.fileno()
buffer = b""
# Set stdin to non-blocking mode
os.set_blocking(stdin_fd, False)
try:
while True:
ready, _, _ = select.select([sys.stdin], [], [], 0.1)
# Check for input using select
ready, _, _ = await loop.run_in_executor(
None, lambda: select.select([stdin_fd], [], [], 0.1)
)
if not ready:
# No input available, yield to other tasks
await asyncio.sleep(0.01)
continue
line = sys.stdin.buffer.readline()
if not line:
break
# Read available data
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk:
break # EOF
# Try to decode and parse as JSON to check for tool calls
buffer += chunk
# Process complete lines
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
await process_line(ctx, mcp_process, line)
except (BrokenPipeError, KeyboardInterrupt):
# Broken pipe = client disappeared, just start shutdown
mcp_log("Client disconnected or keyboard interrupt")
finally:
# Close stdin
if mcp_process.stdin:
mcp_process.stdin.close()
# Process any remaining data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if line:
await process_line(ctx, mcp_process, line)
# Terminate process if needed
if mcp_process.poll() is None:
mcp_process.terminate()
try:
text = line.decode(UTF_8_ENCODING)
parsed_json = json.loads(text)
if parsed_json.get(MCP_METHOD) is not None:
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
MCP_METHOD
)
if "params" in parsed_json and "clientInfo" in parsed_json.get(
"params"
):
ctx.mcp_client_name = (
parsed_json.get("params").get("clientInfo").get("name", "")
)
await asyncio.wait_for(
loop.run_in_executor(None, mcp_process.wait), timeout=2
)
except asyncio.TimeoutError:
mcp_process.kill()
# Check if this is a tool call request
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
# Refresh guardrails
run_task_sync(ctx.load_guardrails)
# Cancel I/O tasks
stdout_task.cancel()
stderr_task.cancel()
# Intercept and potentially block modify the request
hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json)
if not is_blocked:
# If blocked, hook_tool_call_result contains the original request.
# Forward the request to the MCP process.
# It will handle the request and return a response.
mcp_process.stdin.write(
write_as_utf8_bytes(hook_tool_call_result)
)
mcp_process.stdin.flush()
else:
# If blocked, hook_tool_call_result contains the block message.
# Forward the block message result back to the caller.
# The original request is not passed to the MCP process.
sys.stdout.buffer.write(
write_as_utf8_bytes(hook_tool_call_result)
)
sys.stdout.buffer.flush()
continue
else:
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
mcp_process.stdin.flush()
continue
except Exception:
# Not a complete or valid JSON, just pass through
pass
except BrokenPipeError:
pass
except KeyboardInterrupt:
mcp_process.terminate()
# Final flush
sys.stdout.flush()
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
@@ -425,17 +479,9 @@ async def execute(args: list[str] = None):
bufsize=0,
)
# Start threads to forward stdout and stderr
threading.Thread(
target=stream_and_forward_stdout,
args=(mcp_process, ctx),
daemon=True,
).start()
threading.Thread(
target=stream_and_forward_stderr,
args=(mcp_process, ctx),
daemon=True,
).start()
# Start async tasks for stdout and stderr
stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx))
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
# Handle forwarding stdin and intercept tool calls
run_stdio_input_loop(ctx, mcp_process)
await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task)
+9 -4
View File
@@ -17,7 +17,7 @@ MCP_SSE_SERVER_PORT = 8123
@pytest.mark.asyncio
@pytest.mark.timeout(15)
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
"push_to_explorer, transport",
[
@@ -97,7 +97,7 @@ async def test_mcp_with_gateway(
@pytest.mark.asyncio
@pytest.mark.timeout(15)
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_with_gateway_and_logging_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
@@ -205,11 +205,13 @@ async def test_mcp_with_gateway_and_logging_guardrails(
tool_call_annotation is not None
), "Missing 'get_last_message_from_user is called' annotation"
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
assert food_annotation["extra_metadata"]["guardrail"]["action"] == "log"
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
@pytest.mark.asyncio
@pytest.mark.timeout(15)
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_with_gateway_and_blocking_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
@@ -298,10 +300,11 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
and annotations[0]["address"] == "messages.0.tool_calls.0"
)
assert annotations[0]["extra_metadata"]["source"] == "guardrails-error"
assert annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
@pytest.mark.asyncio
@pytest.mark.timeout(15)
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_sse_with_gateway_hybrid_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
@@ -416,4 +419,6 @@ async def test_mcp_sse_with_gateway_hybrid_guardrails(
tool_call_annotation is not None
), "Missing 'get_last_message_from_user is called' annotation"
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
assert food_annotation["extra_metadata"]["guardrail"]["action"] == "block"
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
@@ -1,7 +1,7 @@
"""A MCP client implementation that interacts with MCP server to make tool calls."""
import asyncio
import os
from datetime import timedelta
from contextlib import AsyncExitStack
from typing import Any, Optional
@@ -69,7 +69,7 @@ class MCPClient:
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(
self.stdio, self.write, read_timeout_seconds=timedelta(seconds=10)
self.stdio, self.write, read_timeout_seconds=timedelta(seconds=15)
)
)
@@ -85,10 +85,6 @@ class MCPClient:
tool_name: Name of the tool to call
tool_args: Arguments for the tool call
"""
response = await self.session.list_tools()
if tool_name not in [tool.name for tool in response.tools]:
raise ValueError(f"Tool '{tool_name}' not found in available tools")
# Execute tool call
result = await self.session.call_tool(tool_name, tool_args)
return result
@@ -130,8 +126,4 @@ async def run(
)
return await client.call_tool(tool_name, tool_args)
finally:
# Sleep for a while to allow the server to process the background tasks
# like pushing traces to the explorer
if push_to_explorer:
await asyncio.sleep(2)
await client.cleanup()