Small cleanups in the SSE and stdio MCP implementations.

This commit is contained in:
Hemang
2025-05-21 14:05:51 +02:00
committed by Hemang Sarkar
parent 418c873e04
commit 169eb066b9
7 changed files with 85 additions and 177 deletions

View File

@@ -15,6 +15,7 @@ IGNORED_HEADERS = [
CLIENT_TIMEOUT = 60.0
# MCP related constants
UTF_8 = "utf-8"
MCP_METHOD = "method"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"

View File

@@ -65,23 +65,8 @@ class McpSession(BaseModel):
"""Deduplicate new_annotations using the annotations in the session."""
deduped_annotations = []
for annotation in new_annotations:
# Check if an annotation with the same content and address exists in self.annotations
# TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
# to remove duplicates instead of using a custom logic.
# This is a temporary solution until the Invariant SDK is updated.
is_duplicate = False
for current_annotation in self.annotations:
if (
annotation.content == current_annotation.content
and annotation.address == current_annotation.address
and annotation.extra_metadata == current_annotation.extra_metadata
):
is_duplicate = True
break
if not is_duplicate:
if annotation not in self.annotations:
deduped_annotations.append(annotation)
return deduped_annotations
@contextlib.asynccontextmanager

View File

@@ -66,9 +66,6 @@ def create_annotations_from_guardrails_errors(
)
)
# Remove duplicates
# TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class
# to remove duplicates instead of using a custom function.
# This is a temporary solution until the Invariant SDK is updated.
return remove_duplicates(annotations)
@@ -85,7 +82,7 @@ def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCre
for annotation in annotations:
# Convert the entire extra_metadata dict to a JSON string
# This creates a hashable representation regardless of nested content
metadata_str = json.dumps(annotation.extra_metadata, sort_keys=True)
metadata_str = json.dumps(annotation.extra_metadata or {}, sort_keys=True)
# Create a unique identifier using all three fields
unique_key = (annotation.content, annotation.address, metadata_str)

View File

@@ -1,12 +1,14 @@
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
import sys
import subprocess
import asyncio
import getpass
import json
import os
import select
import asyncio
import platform
import select
import socket
import subprocess
import sys
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
@@ -21,6 +23,7 @@ from gateway.common.constants import (
MCP_SERVER_INFO,
MCP_TOOL_CALL,
MCP_LIST_TOOLS,
UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.request_context import RequestContext
@@ -28,15 +31,17 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err
from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
import getpass
import socket
from gateway.mcp.task_utils import run_task_sync
UTF_8_ENCODING = "utf-8"
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
STATUS_EOF = "eof"
STATUS_DATA = "data"
STATUS_WAIT = "wait"
def user_and_host() -> str:
"""Get the current user and hostname."""
username = getpass.getuser()
hostname = socket.gethostname()
@@ -44,6 +49,7 @@ def user_and_host() -> str:
def session_metadata(ctx: McpContext) -> dict:
"""Generate metadata for the current session."""
return {
"session_id": ctx.local_session_id,
"system_user": user_and_host(),
@@ -56,30 +62,15 @@ def session_metadata(ctx: McpContext) -> dict:
def write_as_utf8_bytes(data: dict) -> bytes:
    """Serialize ``data`` to a newline-terminated, UTF-8 encoded JSON payload.

    Args:
        data: The dict to serialize; must be JSON-serializable.

    Returns:
        The JSON representation of ``data`` encoded with the shared
        ``UTF_8`` constant, plus a trailing newline so each payload is
        line-delimited on the stdio stream.
    """
    # Merge residue fix: the old duplicate return using the removed
    # UTF_8_ENCODING local constant was dead code and has been dropped.
    return json.dumps(data).encode(UTF_8) + b"\n"
def deduplicate_annotations(ctx: "McpContext", new_annotations: list) -> list:
    """Deduplicate new_annotations using the annotations in the context.

    Args:
        ctx: Context whose already-recorded annotations are read via
            ``ctx.annotations``.
        new_annotations: Candidate annotations to filter.

    Returns:
        The annotations from ``new_annotations`` that are not already
        present in ``ctx.annotations``. Order is preserved; note that
        duplicates *within* ``new_annotations`` itself are kept.
    """
    # Diff-residue fix: the interleaved hand-rolled field-by-field
    # comparison (is_duplicate / inner loop) is removed; the membership
    # test delegates to the annotation objects' __eq__ directly.
    return [
        annotation
        for annotation in new_annotations
        if annotation not in ctx.annotations
    ]
@@ -94,43 +85,6 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
return False
async def get_guardrails_check_result(
ctx: McpContext,
message: dict,
action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
"""
Check against guardrails of type action in an async manner.
"""
# Skip if no guardrails are configured for this action
if not (
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
):
return {}
# Prepare context and select appropriate guardrails
context = RequestContext.create(
request_json={},
dataset_name=ctx.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=ctx.guardrails,
)
guardrails_to_check = (
ctx.guardrails.blocking_guardrails
if action == GuardrailAction.BLOCK
else ctx.guardrails.logging_guardrails
)
# Run check_guardrails asynchronously
return await check_guardrails(
messages=ctx.trace + [message],
guardrails=guardrails_to_check,
context=context,
)
async def append_and_push_trace(
ctx: McpContext, message: dict, guardrails_result: dict
) -> None:
@@ -195,7 +149,7 @@ async def append_and_push_trace(
)
ctx.last_trace_length = len(ctx.trace)
ctx.annotations.extend(deduplicated_annotations)
except Exception as e:
except Exception as e: # pylint: disable=broad-except
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
@@ -331,17 +285,17 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
# Safely handle result object
result_obj = result.get("result", {})
if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj:
if result_obj.get(MCP_SERVER_INFO):
ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "")
if method is None:
if not method:
return result
elif method == MCP_TOOL_CALL:
message = {
"role": "tool",
"content": result_obj.get("content"),
"error": result_obj.get("error"),
"tool_call_id": call_id
"tool_call_id": call_id,
}
# Check for blocking guardrails
guardrailing_result = await get_guardrails_check_result(
@@ -417,7 +371,7 @@ async def stream_and_forward_stdout(
try:
# Process complete JSON lines
line_str = line.decode(UTF_8_ENCODING).strip()
line_str = line.decode(UTF_8).strip()
if not line_str:
continue
@@ -431,9 +385,7 @@ async def stream_and_forward_stdout(
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
sys.stdout.buffer.flush()
except Exception as e:
import traceback
mcp_log(traceback.format_exc())
except Exception as e: # pylint: disable=broad-except
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
if line:
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
@@ -458,12 +410,13 @@ async def stream_and_forward_stderr(
async def process_line(
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
) -> None:
"""Process a line of input from stdin, decode it, and forward to mcp_process."""
if ctx.verbose:
mcp_log(f"[INFO] client -> server: {line}")
# Try to decode and parse as JSON to check for tool calls
try:
text = line.decode(UTF_8_ENCODING)
text = line.decode(UTF_8)
parsed_json = json.loads(text)
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
@@ -471,12 +424,10 @@ async def process_line(
return
if parsed_json.get(MCP_METHOD) is not None:
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
MCP_METHOD
)
if "params" in parsed_json and "clientInfo" in parsed_json.get("params"):
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD)
if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO):
ctx.mcp_client_name = (
parsed_json.get("params").get("clientInfo").get("name", "")
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
# Check if this is a tool call request
@@ -506,8 +457,6 @@ async def process_line(
if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
# Refresh guardrails
run_task_sync(ctx.load_guardrails)
# mcp_message_{}
ctx.trace.append(
{
"role": "assistant",
@@ -528,14 +477,16 @@ async def process_line(
mcp_process.stdin.flush()
async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -> tuple[bytes | None, str]:
async def wait_for_stdin_input(
loop: asyncio.AbstractEventLoop, stdin_fd: int
) -> tuple[bytes | None, str]:
"""
Platform-specific implementation to wait for and read input from stdin.
Args:
loop: The asyncio event loop
stdin_fd: The file descriptor for stdin
Returns:
tuple[bytes | None, str]: A tuple containing:
- The data read from stdin or None
@@ -548,11 +499,11 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -
try:
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk: # Empty bytes means EOF
return None, 'eof'
return chunk, 'data'
return None, STATUS_EOF
return chunk, STATUS_DATA
except (BlockingIOError, OSError):
# No data available yet
return None, 'wait'
return None, STATUS_WAIT
else:
# On Unix-like systems, use select
ready, _, _ = await loop.run_in_executor(
@@ -562,13 +513,13 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -
if not ready:
# No input available, yield to other tasks
await asyncio.sleep(0.01)
return None, 'wait'
return None, STATUS_WAIT
# Read available data
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
if not chunk: # Empty bytes means EOF
return None, 'eof'
return chunk, 'data'
return None, STATUS_EOF
return chunk, STATUS_DATA
async def run_stdio_input_loop(
@@ -589,14 +540,14 @@ async def run_stdio_input_loop(
while True:
# Get input using platform-specific method
chunk, status = await wait_for_stdin_input(loop, stdin_fd)
if status == 'eof':
if status == STATUS_EOF:
# EOF detected, break the loop
break
elif status == 'wait':
elif status == STATUS_WAIT:
# No data available yet, continue polling
continue
elif status == 'data':
elif status == STATUS_DATA:
# We got some data, process it
buffer += chunk

View File

@@ -2,41 +2,11 @@
import asyncio
import concurrent.futures
import threading
from contextlib import redirect_stdout
from typing import Any
from gateway.mcp.log import MCP_LOG_FILE, mcp_log
def run_task_in_background(async_func, *args, **kwargs):
    """Fire-and-forget execution of an async function on a daemon thread.

    A fresh event loop is created on a dedicated daemon thread and
    ``async_func(*args, **kwargs)`` is driven to completion there. The
    call returns immediately — it never blocks the caller and exposes no
    handle to the background work.

    Args:
        async_func: The async function to run.
        *args: Positional arguments forwarded to ``async_func``.
        **kwargs: Keyword arguments forwarded to ``async_func``.
    """

    def _drive_coroutine():
        # Every worker thread must create, install, and close its own loop.
        worker_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(worker_loop)
        try:
            worker_loop.run_until_complete(async_func(*args, **kwargs))
        except Exception as e:  # pylint: disable=broad-except
            # Log instead of raising: an exception escaping a thread
            # target would otherwise be lost silently.
            mcp_log(
                f"[ERROR] Error in async thread while running run_task_in_background: {e}"
            )
        finally:
            worker_loop.close()

    # Daemon thread so a pending background task never keeps the process
    # alive on shutdown.
    threading.Thread(target=_drive_coroutine, daemon=True).start()
from gateway.mcp.log import MCP_LOG_FILE
def run_task_sync(async_func, *args, **kwargs) -> Any:

View File

@@ -22,6 +22,7 @@ from gateway.common.constants import (
MCP_RESULT,
MCP_SERVER_INFO,
MCP_CLIENT_INFO,
UTF_8,
)
from gateway.common.guardrails import GuardrailAction
from gateway.common.mcp_sessions_manager import (
@@ -42,12 +43,12 @@ MCP_SERVER_SSE_HEADERS = {
"accept",
"cache-control",
}
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
gateway = APIRouter()
session_store = McpSessionsManager()
@gateway.post("/mcp/messages/")
@gateway.post("/mcp/sse/messages/")
async def mcp_post_gateway(
request: Request,
@@ -64,15 +65,17 @@ async def mcp_post_gateway(
status_code=400,
detail="Session does not exist",
)
if not request.headers.get("mcp-server-base-url"):
if not request.headers.get(MCP_SERVER_BASE_URL_HEADER):
return HTTPException(
status_code=400,
detail="Missing 'mcp-server-base-url' header",
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
session_id = query_params.get("session_id")
mcp_server_messages_endpoint = (
_convert_localhost_to_docker_host(request.headers.get("mcp-server-base-url"))
_convert_localhost_to_docker_host(
request.headers.get(MCP_SERVER_BASE_URL_HEADER)
)
+ "/messages/?"
+ session_id
)
@@ -103,14 +106,12 @@ async def mcp_post_gateway(
elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
# Intercept and potentially block the request
hook_tool_call_result, is_blocked = await _hook_tool_call(
session_id=session_id, request_json={
session_id=session_id,
request_json={
"id": request_json.get("id"),
"method": MCP_LIST_TOOLS,
"params": {
"name": MCP_LIST_TOOLS,
"arguments": {}
},
}
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
},
)
if is_blocked:
# Add the error message to the session.
@@ -147,18 +148,16 @@ async def mcp_post_gateway(
raise HTTPException(status_code=500, detail="Unexpected error") from e
@gateway.get("/mcp/sse")
async def mcp_get_sse_gateway(
request: Request,
) -> StreamingResponse:
"""Proxy calls to the MCP Server tools"""
mcp_server_base_url = request.headers.get("mcp-server-base-url")
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
if not mcp_server_base_url:
print("missing base url", request.headers, flush=True)
raise HTTPException(
status_code=400,
detail="Missing 'mcp-server-base-url' header",
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
mcp_server_sse_endpoint = (
_convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
@@ -243,7 +242,7 @@ async def mcp_get_sse_gateway(
# Pass through other event types
# pylint: disable=line-too-long
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
"utf-8"
UTF_8
)
# Put the processed event in the queue
@@ -251,7 +250,7 @@ async def mcp_get_sse_gateway(
except httpx.StreamClosed as e:
print(f"Server stream closed: {e}", flush=True)
except Exception as e:
except Exception as e: # pylint: disable=broad-except
print(f"Error processing server events: {e}", flush=True)
# Start server events processor
@@ -360,7 +359,9 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo
return request_json, False
async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict:
async def _hook_tool_call_response(
session_id: str, response_json: dict, is_tools_list=False
) -> dict:
"""
Hook to process the response JSON after receiving it from the MCP server.
@@ -404,10 +405,10 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool
else:
# special error response for tools/list tool call
result = {
"jsonrpc": "2.0",
"id": response_json.get("id"),
"result": {
"tools": [
"jsonrpc": "2.0",
"id": response_json.get("id"),
"result": {
"tools": [
{
"name": "blocked_" + tool["name"],
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
@@ -425,7 +426,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool
}
for tool in response_json["result"]["tools"]
]
}
},
}
# Push trace to the explorer - don't block on its response
@@ -489,7 +490,7 @@ async def _handle_endpoint_event(
"/messages/?session_id=",
"/api/v1/gateway/mcp/sse/messages/?session_id=",
)
event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode(UTF_8)
return event_bytes, session_id
@@ -501,7 +502,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
session_id (str): The session ID associated with the request.
sse (ServerSentEvent): The original SSE object.
"""
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8)
session = session_store.get_session(session_id)
try:
response_json = json.loads(sse.data)
@@ -525,7 +526,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# pylint: disable=line-too-long
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
UTF_8
)
elif method == MCP_LIST_TOOLS:
# store tools in metadata
@@ -538,9 +539,11 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
response_json={
"id": response_json.get("id"),
"result": {
"content": json.dumps(response_json.get(MCP_RESULT).get("tools")),
"content": json.dumps(
response_json.get(MCP_RESULT).get("tools")
),
"tools": response_json.get(MCP_RESULT).get("tools"),
}
},
},
is_tools_list=True,
)
@@ -550,7 +553,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
# pylint: disable=line-too-long
if blocked:
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
"utf-8"
UTF_8
)
except json.JSONDecodeError as e:
@@ -591,7 +594,7 @@ async def _check_for_pending_error_messages(
for error_message in error_messages:
error_bytes = (
f"event: message\ndata: {json.dumps(error_message)}\n\n".encode(
"utf-8"
UTF_8
)
)
await pending_error_messages_queue.put(error_bytes)

View File

@@ -45,7 +45,7 @@ async def test_mcp_with_gateway(
project_name,
push_to_explorer=push_to_explorer,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"}
tool_args={"username": "Alice"},
)
else:
result = await mcp_stdio_client_run(
@@ -307,7 +307,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
async def test_mcp_sse_with_gateway_hybrid_guardrails(
async def test_mcp_with_gateway_hybrid_guardrails(
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
):
"""Test MCP gateway and verify that logging and blocking guardrails work together"""
@@ -425,7 +425,6 @@ async def test_mcp_sse_with_gateway_hybrid_guardrails(
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
@pytest.mark.asyncio
@pytest.mark.timeout(30)
@pytest.mark.parametrize("transport", ["stdio", "sse"])
@@ -434,7 +433,7 @@ async def test_mcp_tool_list_blocking(
):
"""
Tests that blocking guardrails work for the tools/list call.
For those, the expected behavior is that the returned tools are all renamed to blocked_... and include an informative block notice, instead of the original tool description.
"""
project_name = "test-mcp-" + str(uuid.uuid4())
@@ -473,5 +472,7 @@ async def test_mcp_tool_list_blocking(
tool_args={},
)
assert "blocked_get_last_message_from_user" in str(tools_result), "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + str(tools_result)
assert "blocked_get_last_message_from_user" in str(tools_result), (
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result)
)