mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-02 17:15:41 +02:00
Make MCP stdio gateway fully async. With sync and async mixed behaviour for running background tasks we were running into issues.
This commit is contained in:
+187
-141
@@ -5,7 +5,7 @@ import subprocess
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
@@ -14,6 +14,9 @@ from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
from gateway.common.constants import (
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_CLIENT_INFO,
|
||||
MCP_PARAMS,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
)
|
||||
@@ -23,7 +26,6 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
|
||||
UTF_8_ENCODING = "utf-8"
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
@@ -69,16 +71,50 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
ctx: McpContext,
|
||||
message: dict,
|
||||
action: GuardrailAction = GuardrailAction.BLOCK,
|
||||
) -> dict:
|
||||
"""
|
||||
Check against guardrails of type action in an async manner.
|
||||
"""
|
||||
# Skip if no guardrails are configured for this action
|
||||
if not (
|
||||
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
|
||||
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
|
||||
):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=ctx.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=ctx.guardrails,
|
||||
)
|
||||
|
||||
guardrails_to_check = (
|
||||
ctx.guardrails.blocking_guardrails
|
||||
if action == GuardrailAction.BLOCK
|
||||
else ctx.guardrails.logging_guardrails
|
||||
)
|
||||
|
||||
# Run check_guardrails asynchronously
|
||||
return await check_guardrails(
|
||||
messages=ctx.trace + [message],
|
||||
guardrails=guardrails_to_check,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
async def append_and_push_trace(
|
||||
ctx: McpContext, message: dict, guardrails_result: dict
|
||||
) -> None:
|
||||
"""
|
||||
Append a message to the trace if it exists or create a new one
|
||||
and push it to the Invariant Explorer.
|
||||
|
||||
This function runs asynchronously in the background.
|
||||
"""
|
||||
|
||||
annotations = []
|
||||
if guardrails_result and guardrails_result.get("errors", []):
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
@@ -86,7 +122,7 @@ async def append_and_push_trace(
|
||||
)
|
||||
|
||||
if ctx.guardrails.logging_guardrails:
|
||||
logging_guardrails_check_result = get_guardrails_check_result(
|
||||
logging_guardrails_check_result = await get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.LOG
|
||||
)
|
||||
if logging_guardrails_check_result and logging_guardrails_check_result.get(
|
||||
@@ -138,45 +174,7 @@ async def append_and_push_trace(
|
||||
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
|
||||
|
||||
|
||||
def get_guardrails_check_result(
|
||||
ctx: McpContext,
|
||||
message: dict,
|
||||
action: GuardrailAction = GuardrailAction.BLOCK,
|
||||
) -> dict:
|
||||
"""
|
||||
Check against guardrails of type action.
|
||||
Works in both sync and async contexts by always using a dedicated thread.
|
||||
"""
|
||||
# Skip if no guardrails are configured for this action
|
||||
if not (
|
||||
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
|
||||
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
|
||||
):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=ctx.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=ctx.guardrails,
|
||||
)
|
||||
|
||||
guardrails_to_check = (
|
||||
ctx.guardrails.blocking_guardrails
|
||||
if action == GuardrailAction.BLOCK
|
||||
else ctx.guardrails.logging_guardrails
|
||||
)
|
||||
|
||||
return run_task_sync(
|
||||
check_guardrails,
|
||||
messages=ctx.trace + [message],
|
||||
guardrails=guardrails_to_check,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
|
||||
async def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
|
||||
"""
|
||||
Hook function to intercept tool calls.
|
||||
|
||||
@@ -195,23 +193,20 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
|
||||
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
guardrailing_result = get_guardrails_check_result(
|
||||
# Check for blocking guardrails
|
||||
guardrailing_result = await get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrailing_result
|
||||
and guardrailing_result.get("errors", [])
|
||||
and check_if_new_errors(ctx, guardrailing_result)
|
||||
):
|
||||
if ctx.push_explorer:
|
||||
run_task_in_background(
|
||||
append_and_push_trace, ctx, message, guardrailing_result
|
||||
)
|
||||
await append_and_push_trace(ctx, message, guardrailing_result)
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
@@ -227,28 +222,30 @@ def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
|
||||
return request, False
|
||||
|
||||
|
||||
def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
"""
|
||||
Hook function to intercept tool results.
|
||||
Modify this function to change behavior for tool results.
|
||||
Returns the potentially modified result.
|
||||
"""
|
||||
method = ctx.id_to_method_mapping.get(result.get("id"))
|
||||
call_id = f"call_{result.get('id')}"
|
||||
if "serverInfo" in result.get("result"):
|
||||
ctx.mcp_server_name = result.get("result").get("serverInfo").get("name", "")
|
||||
|
||||
# Safely handle result object
|
||||
result_obj = result.get("result", {})
|
||||
if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj:
|
||||
ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "")
|
||||
|
||||
if method is None:
|
||||
return result
|
||||
elif method == MCP_TOOL_CALL:
|
||||
message = {
|
||||
"role": "tool",
|
||||
"content": result.get("result").get("content"),
|
||||
"error": result.get("result").get("error"),
|
||||
"content": result_obj.get("content"),
|
||||
"error": result_obj.get("error"),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
guardrailing_result = get_guardrails_check_result(
|
||||
# Check for blocking guardrails
|
||||
guardrailing_result = await get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
@@ -264,21 +261,27 @@ def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
}
|
||||
|
||||
if ctx.push_explorer:
|
||||
# Run append_and_push_trace in background
|
||||
run_task_in_background(
|
||||
append_and_push_trace, ctx, message, guardrailing_result
|
||||
)
|
||||
await append_and_push_trace(ctx, message, guardrailing_result)
|
||||
|
||||
return result
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
ctx.tools = result.get("result").get("tools")
|
||||
ctx.tools = result_obj.get("tools")
|
||||
return result
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None:
|
||||
"""Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout"""
|
||||
for line in iter(mcp_process.stdout.readline, b""):
|
||||
async def stream_and_forward_stdout(
|
||||
mcp_process: subprocess.Popen, ctx: McpContext
|
||||
) -> None:
|
||||
"""Read from the mcp_process stdout, apply guardrails and forward to sys.stdout"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
line = await loop.run_in_executor(None, mcp_process.stdout.readline)
|
||||
if not line:
|
||||
break
|
||||
|
||||
try:
|
||||
# Process complete JSON lines
|
||||
line_str = line.decode(UTF_8_ENCODING).strip()
|
||||
@@ -286,14 +289,11 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) ->
|
||||
continue
|
||||
|
||||
parsed_json = json.loads(line_str)
|
||||
processed_json = hook_tool_result(ctx, parsed_json)
|
||||
processed_json = await hook_tool_result(ctx, parsed_json)
|
||||
|
||||
# Write and flush immediately
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
|
||||
sys.stdout.buffer.flush()
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}")
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
|
||||
except Exception as e:
|
||||
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
|
||||
@@ -301,79 +301,133 @@ def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) ->
|
||||
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
|
||||
|
||||
|
||||
def stream_and_forward_stderr(
|
||||
mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1
|
||||
async def stream_and_forward_stderr(
|
||||
mcp_process: subprocess.Popen, read_chunk_size: int = 10
|
||||
) -> None:
|
||||
"""Read from the mcp_process stderr and write to sys.stderr"""
|
||||
for line in iter(lambda: mcp_process.stderr.read(read_chunk_size), b""):
|
||||
MCP_LOG_FILE.buffer.write(line)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
while True:
|
||||
# Read chunks asynchronously
|
||||
chunk = await loop.run_in_executor(
|
||||
None, lambda: mcp_process.stderr.read(read_chunk_size)
|
||||
)
|
||||
|
||||
MCP_LOG_FILE.buffer.write(chunk)
|
||||
MCP_LOG_FILE.buffer.flush()
|
||||
|
||||
|
||||
def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None:
|
||||
async def process_line(
|
||||
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
|
||||
) -> None:
|
||||
"""Process a line of input from stdin, check for tool calls and forward to mcp_process."""
|
||||
try:
|
||||
text = line.decode(UTF_8_ENCODING)
|
||||
parsed_json = json.loads(text)
|
||||
|
||||
if parsed_json.get(MCP_METHOD) is not None:
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
if MCP_PARAMS in parsed_json and MCP_CLIENT_INFO in parsed_json.get(MCP_PARAMS):
|
||||
ctx.mcp_client_name = (
|
||||
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
# Check if this is a tool call request
|
||||
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Refresh guardrails
|
||||
await ctx.load_guardrails()
|
||||
|
||||
# Intercept and potentially block/modify the request
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(ctx, parsed_json)
|
||||
if not is_blocked:
|
||||
# Forward the request to the MCP process
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
mcp_process.stdin.flush()
|
||||
else:
|
||||
# Forward the block message result back to the caller
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(hook_tool_call_result))
|
||||
sys.stdout.buffer.flush()
|
||||
else:
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
|
||||
mcp_process.stdin.flush()
|
||||
except Exception: # pylint: disable=bare-except
|
||||
# Not a complete or valid JSON, just pass through
|
||||
pass
|
||||
|
||||
|
||||
async def run_stdio_input_loop(
|
||||
ctx: McpContext,
|
||||
mcp_process: subprocess.Popen,
|
||||
stdout_task: asyncio.Task,
|
||||
stderr_task: asyncio.Task,
|
||||
) -> None:
|
||||
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
|
||||
loop = asyncio.get_event_loop()
|
||||
stdin_fd = sys.stdin.fileno()
|
||||
buffer = b""
|
||||
|
||||
# Set stdin to non-blocking mode
|
||||
os.set_blocking(stdin_fd, False)
|
||||
|
||||
try:
|
||||
while True:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0.1)
|
||||
# Check for input using select
|
||||
ready, _, _ = await loop.run_in_executor(
|
||||
None, lambda: select.select([stdin_fd], [], [], 0.1)
|
||||
)
|
||||
|
||||
if not ready:
|
||||
# No input available, yield to other tasks
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
line = sys.stdin.buffer.readline()
|
||||
if not line:
|
||||
break
|
||||
# Read available data
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk:
|
||||
break # EOF
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
buffer += chunk
|
||||
|
||||
# Process complete lines
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
|
||||
await process_line(ctx, mcp_process, line)
|
||||
|
||||
except (BrokenPipeError, KeyboardInterrupt):
|
||||
# Broken pipe = client disappeared, just start shutdown
|
||||
mcp_log("Client disconnected or keyboard interrupt")
|
||||
finally:
|
||||
# Close stdin
|
||||
if mcp_process.stdin:
|
||||
mcp_process.stdin.close()
|
||||
|
||||
# Process any remaining data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if line:
|
||||
await process_line(ctx, mcp_process, line)
|
||||
|
||||
# Terminate process if needed
|
||||
if mcp_process.poll() is None:
|
||||
mcp_process.terminate()
|
||||
try:
|
||||
text = line.decode(UTF_8_ENCODING)
|
||||
parsed_json = json.loads(text)
|
||||
if parsed_json.get(MCP_METHOD) is not None:
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
if "params" in parsed_json and "clientInfo" in parsed_json.get(
|
||||
"params"
|
||||
):
|
||||
ctx.mcp_client_name = (
|
||||
parsed_json.get("params").get("clientInfo").get("name", "")
|
||||
)
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, mcp_process.wait), timeout=2
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
mcp_process.kill()
|
||||
|
||||
# Check if this is a tool call request
|
||||
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Refresh guardrails
|
||||
run_task_sync(ctx.load_guardrails)
|
||||
# Cancel I/O tasks
|
||||
stdout_task.cancel()
|
||||
stderr_task.cancel()
|
||||
|
||||
# Intercept and potentially block modify the request
|
||||
hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json)
|
||||
if not is_blocked:
|
||||
# If blocked, hook_tool_call_result contains the original request.
|
||||
# Forward the request to the MCP process.
|
||||
# It will handle the request and return a response.
|
||||
mcp_process.stdin.write(
|
||||
write_as_utf8_bytes(hook_tool_call_result)
|
||||
)
|
||||
mcp_process.stdin.flush()
|
||||
else:
|
||||
# If blocked, hook_tool_call_result contains the block message.
|
||||
# Forward the block message result back to the caller.
|
||||
# The original request is not passed to the MCP process.
|
||||
sys.stdout.buffer.write(
|
||||
write_as_utf8_bytes(hook_tool_call_result)
|
||||
)
|
||||
sys.stdout.buffer.flush()
|
||||
continue
|
||||
else:
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
|
||||
mcp_process.stdin.flush()
|
||||
continue
|
||||
except Exception:
|
||||
# Not a complete or valid JSON, just pass through
|
||||
pass
|
||||
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
mcp_process.terminate()
|
||||
# Final flush
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
|
||||
@@ -425,17 +479,9 @@ async def execute(args: list[str] = None):
|
||||
bufsize=0,
|
||||
)
|
||||
|
||||
# Start threads to forward stdout and stderr
|
||||
threading.Thread(
|
||||
target=stream_and_forward_stdout,
|
||||
args=(mcp_process, ctx),
|
||||
daemon=True,
|
||||
).start()
|
||||
threading.Thread(
|
||||
target=stream_and_forward_stderr,
|
||||
args=(mcp_process, ctx),
|
||||
daemon=True,
|
||||
).start()
|
||||
# Start async tasks for stdout and stderr
|
||||
stdout_task = asyncio.create_task(stream_and_forward_stdout(mcp_process, ctx))
|
||||
stderr_task = asyncio.create_task(stream_and_forward_stderr(mcp_process))
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
run_stdio_input_loop(ctx, mcp_process)
|
||||
await run_stdio_input_loop(ctx, mcp_process, stdout_task, stderr_task)
|
||||
|
||||
@@ -17,7 +17,7 @@ MCP_SSE_SERVER_PORT = 8123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(15)
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize(
|
||||
"push_to_explorer, transport",
|
||||
[
|
||||
@@ -97,7 +97,7 @@ async def test_mcp_with_gateway(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(15)
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
@@ -205,11 +205,13 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
tool_call_annotation is not None
|
||||
), "Missing 'get_last_message_from_user is called' annotation"
|
||||
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert food_annotation["extra_metadata"]["guardrail"]["action"] == "log"
|
||||
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(15)
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
@@ -298,10 +300,11 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
and annotations[0]["address"] == "messages.0.tool_calls.0"
|
||||
)
|
||||
assert annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(15)
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
async def test_mcp_sse_with_gateway_hybrid_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
@@ -416,4 +419,6 @@ async def test_mcp_sse_with_gateway_hybrid_guardrails(
|
||||
tool_call_annotation is not None
|
||||
), "Missing 'get_last_message_from_user is called' annotation"
|
||||
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert food_annotation["extra_metadata"]["guardrail"]["action"] == "block"
|
||||
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""A MCP client implementation that interacts with MCP server to make tool calls."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from datetime import timedelta
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Optional
|
||||
@@ -69,7 +69,7 @@ class MCPClient:
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
self.stdio, self.write, read_timeout_seconds=timedelta(seconds=10)
|
||||
self.stdio, self.write, read_timeout_seconds=timedelta(seconds=15)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -85,10 +85,6 @@ class MCPClient:
|
||||
tool_name: Name of the tool to call
|
||||
tool_args: Arguments for the tool call
|
||||
"""
|
||||
response = await self.session.list_tools()
|
||||
if tool_name not in [tool.name for tool in response.tools]:
|
||||
raise ValueError(f"Tool '{tool_name}' not found in available tools")
|
||||
|
||||
# Execute tool call
|
||||
result = await self.session.call_tool(tool_name, tool_args)
|
||||
return result
|
||||
@@ -130,8 +126,4 @@ async def run(
|
||||
)
|
||||
return await client.call_tool(tool_name, tool_args)
|
||||
finally:
|
||||
# Sleep for a while to allow the server to process the background tasks
|
||||
# like pushing traces to the explorer
|
||||
if push_to_explorer:
|
||||
await asyncio.sleep(2)
|
||||
await client.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user