mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Small cleanups in sse and stdio MCP implementation.
This commit is contained in:
@@ -15,6 +15,7 @@ IGNORED_HEADERS = [
|
||||
CLIENT_TIMEOUT = 60.0
|
||||
|
||||
# MCP related constants
|
||||
UTF_8 = "utf-8"
|
||||
MCP_METHOD = "method"
|
||||
MCP_TOOL_CALL = "tools/call"
|
||||
MCP_LIST_TOOLS = "tools/list"
|
||||
|
||||
@@ -65,23 +65,8 @@ class McpSession(BaseModel):
|
||||
"""Deduplicate new_annotations using the annotations in the session."""
|
||||
deduped_annotations = []
|
||||
for annotation in new_annotations:
|
||||
# Check if an annotation with the same content and address exists in self.annotations
|
||||
# TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
|
||||
# to remove duplicates instead of using a custom logic.
|
||||
# This is a temporary solution until the Invariant SDK is updated.
|
||||
is_duplicate = False
|
||||
for current_annotation in self.annotations:
|
||||
if (
|
||||
annotation.content == current_annotation.content
|
||||
and annotation.address == current_annotation.address
|
||||
and annotation.extra_metadata == current_annotation.extra_metadata
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
if annotation not in self.annotations:
|
||||
deduped_annotations.append(annotation)
|
||||
|
||||
return deduped_annotations
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
|
||||
@@ -66,9 +66,6 @@ def create_annotations_from_guardrails_errors(
|
||||
)
|
||||
)
|
||||
# Remove duplicates
|
||||
# TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class
|
||||
# to remove duplicates instead of using a custom function.
|
||||
# This is a temporary solution until the Invariant SDK is updated.
|
||||
return remove_duplicates(annotations)
|
||||
|
||||
|
||||
@@ -85,7 +82,7 @@ def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCre
|
||||
for annotation in annotations:
|
||||
# Convert the entire extra_metadata dict to a JSON string
|
||||
# This creates a hashable representation regardless of nested content
|
||||
metadata_str = json.dumps(annotation.extra_metadata, sort_keys=True)
|
||||
metadata_str = json.dumps(annotation.extra_metadata or {}, sort_keys=True)
|
||||
|
||||
# Create a unique identifier using all three fields
|
||||
unique_key = (annotation.content, annotation.address, metadata_str)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
import asyncio
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import asyncio
|
||||
import platform
|
||||
import select
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
@@ -21,6 +23,7 @@ from gateway.common.constants import (
|
||||
MCP_SERVER_INFO,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
@@ -28,15 +31,17 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
import getpass
|
||||
import socket
|
||||
from gateway.mcp.task_utils import run_task_sync
|
||||
|
||||
|
||||
UTF_8_ENCODING = "utf-8"
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
STATUS_WAIT = "wait"
|
||||
|
||||
|
||||
def user_and_host() -> str:
|
||||
"""Get the current user and hostname."""
|
||||
username = getpass.getuser()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
@@ -44,6 +49,7 @@ def user_and_host() -> str:
|
||||
|
||||
|
||||
def session_metadata(ctx: McpContext) -> dict:
|
||||
"""Generate metadata for the current session."""
|
||||
return {
|
||||
"session_id": ctx.local_session_id,
|
||||
"system_user": user_and_host(),
|
||||
@@ -56,30 +62,15 @@ def session_metadata(ctx: McpContext) -> dict:
|
||||
|
||||
def write_as_utf8_bytes(data: dict) -> bytes:
|
||||
"""Serializes dict to bytes using UTF-8 encoding."""
|
||||
return json.dumps(data).encode(UTF_8_ENCODING) + b"\n"
|
||||
return json.dumps(data).encode(UTF_8) + b"\n"
|
||||
|
||||
|
||||
def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list:
|
||||
"""Deduplicate new_annotations using the annotations in the context."""
|
||||
deduped_annotations = []
|
||||
for annotation in new_annotations:
|
||||
# Check if an annotation with the same content and address exists in ctx.annotations
|
||||
# TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
|
||||
# to remove duplicates instead of using a custom logic.
|
||||
# This is a temporary solution until the Invariant SDK is updated.
|
||||
is_duplicate = False
|
||||
for ctx_annotation in ctx.annotations:
|
||||
if (
|
||||
annotation.content == ctx_annotation.content
|
||||
and annotation.address == ctx_annotation.address
|
||||
and annotation.extra_metadata == ctx_annotation.extra_metadata
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
if annotation not in ctx.annotations:
|
||||
deduped_annotations.append(annotation)
|
||||
|
||||
return deduped_annotations
|
||||
|
||||
|
||||
@@ -94,43 +85,6 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
ctx: McpContext,
|
||||
message: dict,
|
||||
action: GuardrailAction = GuardrailAction.BLOCK,
|
||||
) -> dict:
|
||||
"""
|
||||
Check against guardrails of type action in an async manner.
|
||||
"""
|
||||
# Skip if no guardrails are configured for this action
|
||||
if not (
|
||||
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
|
||||
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
|
||||
):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=ctx.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=ctx.guardrails,
|
||||
)
|
||||
|
||||
guardrails_to_check = (
|
||||
ctx.guardrails.blocking_guardrails
|
||||
if action == GuardrailAction.BLOCK
|
||||
else ctx.guardrails.logging_guardrails
|
||||
)
|
||||
|
||||
# Run check_guardrails asynchronously
|
||||
return await check_guardrails(
|
||||
messages=ctx.trace + [message],
|
||||
guardrails=guardrails_to_check,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
async def append_and_push_trace(
|
||||
ctx: McpContext, message: dict, guardrails_result: dict
|
||||
) -> None:
|
||||
@@ -195,7 +149,7 @@ async def append_and_push_trace(
|
||||
)
|
||||
ctx.last_trace_length = len(ctx.trace)
|
||||
ctx.annotations.extend(deduplicated_annotations)
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
|
||||
|
||||
|
||||
@@ -331,17 +285,17 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
|
||||
# Safely handle result object
|
||||
result_obj = result.get("result", {})
|
||||
if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj:
|
||||
if result_obj.get(MCP_SERVER_INFO):
|
||||
ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "")
|
||||
|
||||
if method is None:
|
||||
if not method:
|
||||
return result
|
||||
elif method == MCP_TOOL_CALL:
|
||||
message = {
|
||||
"role": "tool",
|
||||
"content": result_obj.get("content"),
|
||||
"error": result_obj.get("error"),
|
||||
"tool_call_id": call_id
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
# Check for blocking guardrails
|
||||
guardrailing_result = await get_guardrails_check_result(
|
||||
@@ -417,7 +371,7 @@ async def stream_and_forward_stdout(
|
||||
|
||||
try:
|
||||
# Process complete JSON lines
|
||||
line_str = line.decode(UTF_8_ENCODING).strip()
|
||||
line_str = line.decode(UTF_8).strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
@@ -431,9 +385,7 @@ async def stream_and_forward_stdout(
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
mcp_log(traceback.format_exc())
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
|
||||
if line:
|
||||
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
|
||||
@@ -458,12 +410,13 @@ async def stream_and_forward_stderr(
|
||||
async def process_line(
|
||||
ctx: McpContext, mcp_process: subprocess.Popen, line: bytes
|
||||
) -> None:
|
||||
"""Process a line of input from stdin, decode it, and forward to mcp_process."""
|
||||
if ctx.verbose:
|
||||
mcp_log(f"[INFO] client -> server: {line}")
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
try:
|
||||
text = line.decode(UTF_8_ENCODING)
|
||||
text = line.decode(UTF_8)
|
||||
parsed_json = json.loads(text)
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}")
|
||||
@@ -471,12 +424,10 @@ async def process_line(
|
||||
return
|
||||
|
||||
if parsed_json.get(MCP_METHOD) is not None:
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
if "params" in parsed_json and "clientInfo" in parsed_json.get("params"):
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD)
|
||||
if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO):
|
||||
ctx.mcp_client_name = (
|
||||
parsed_json.get("params").get("clientInfo").get("name", "")
|
||||
parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
# Check if this is a tool call request
|
||||
@@ -506,8 +457,6 @@ async def process_line(
|
||||
if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
# Refresh guardrails
|
||||
run_task_sync(ctx.load_guardrails)
|
||||
|
||||
# mcp_message_{}
|
||||
ctx.trace.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -528,14 +477,16 @@ async def process_line(
|
||||
mcp_process.stdin.flush()
|
||||
|
||||
|
||||
async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -> tuple[bytes | None, str]:
|
||||
async def wait_for_stdin_input(
|
||||
loop: asyncio.AbstractEventLoop, stdin_fd: int
|
||||
) -> tuple[bytes | None, str]:
|
||||
"""
|
||||
Platform-specific implementation to wait for and read input from stdin.
|
||||
|
||||
|
||||
Args:
|
||||
loop: The asyncio event loop
|
||||
stdin_fd: The file descriptor for stdin
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bytes | None, str]: A tuple containing:
|
||||
- The data read from stdin or None
|
||||
@@ -548,11 +499,11 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -
|
||||
try:
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk: # Empty bytes means EOF
|
||||
return None, 'eof'
|
||||
return chunk, 'data'
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
except (BlockingIOError, OSError):
|
||||
# No data available yet
|
||||
return None, 'wait'
|
||||
return None, STATUS_WAIT
|
||||
else:
|
||||
# On Unix-like systems, use select
|
||||
ready, _, _ = await loop.run_in_executor(
|
||||
@@ -562,13 +513,13 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -
|
||||
if not ready:
|
||||
# No input available, yield to other tasks
|
||||
await asyncio.sleep(0.01)
|
||||
return None, 'wait'
|
||||
return None, STATUS_WAIT
|
||||
|
||||
# Read available data
|
||||
chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096))
|
||||
if not chunk: # Empty bytes means EOF
|
||||
return None, 'eof'
|
||||
return chunk, 'data'
|
||||
return None, STATUS_EOF
|
||||
return chunk, STATUS_DATA
|
||||
|
||||
|
||||
async def run_stdio_input_loop(
|
||||
@@ -589,14 +540,14 @@ async def run_stdio_input_loop(
|
||||
while True:
|
||||
# Get input using platform-specific method
|
||||
chunk, status = await wait_for_stdin_input(loop, stdin_fd)
|
||||
|
||||
if status == 'eof':
|
||||
|
||||
if status == STATUS_EOF:
|
||||
# EOF detected, break the loop
|
||||
break
|
||||
elif status == 'wait':
|
||||
elif status == STATUS_WAIT:
|
||||
# No data available yet, continue polling
|
||||
continue
|
||||
elif status == 'data':
|
||||
elif status == STATUS_DATA:
|
||||
# We got some data, process it
|
||||
buffer += chunk
|
||||
|
||||
|
||||
@@ -2,41 +2,11 @@
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any
|
||||
|
||||
from gateway.mcp.log import MCP_LOG_FILE, mcp_log
|
||||
|
||||
|
||||
def run_task_in_background(async_func, *args, **kwargs):
|
||||
"""
|
||||
Runs an async function in a background thread with its own event loop.
|
||||
This function does NOT block the calling thread as it immediately returns
|
||||
after starting the background thread.
|
||||
|
||||
Args:
|
||||
async_func: The async function to run
|
||||
*args: Positional arguments to pass to the async function
|
||||
**kwargs: Keyword arguments to pass to the async function
|
||||
"""
|
||||
|
||||
def thread_target():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(async_func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
mcp_log(
|
||||
f"[ERROR] Error in async thread while running run_task_in_background: {e}"
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Create and start a daemon thread
|
||||
thread = threading.Thread(target=thread_target, daemon=True)
|
||||
thread.start()
|
||||
from gateway.mcp.log import MCP_LOG_FILE
|
||||
|
||||
|
||||
def run_task_sync(async_func, *args, **kwargs) -> Any:
|
||||
|
||||
@@ -22,6 +22,7 @@ from gateway.common.constants import (
|
||||
MCP_RESULT,
|
||||
MCP_SERVER_INFO,
|
||||
MCP_CLIENT_INFO,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
@@ -42,12 +43,12 @@ MCP_SERVER_SSE_HEADERS = {
|
||||
"accept",
|
||||
"cache-control",
|
||||
}
|
||||
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
|
||||
|
||||
gateway = APIRouter()
|
||||
session_store = McpSessionsManager()
|
||||
|
||||
|
||||
@gateway.post("/mcp/messages/")
|
||||
@gateway.post("/mcp/sse/messages/")
|
||||
async def mcp_post_gateway(
|
||||
request: Request,
|
||||
@@ -64,15 +65,17 @@ async def mcp_post_gateway(
|
||||
status_code=400,
|
||||
detail="Session does not exist",
|
||||
)
|
||||
if not request.headers.get("mcp-server-base-url"):
|
||||
if not request.headers.get(MCP_SERVER_BASE_URL_HEADER):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing 'mcp-server-base-url' header",
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
|
||||
session_id = query_params.get("session_id")
|
||||
mcp_server_messages_endpoint = (
|
||||
_convert_localhost_to_docker_host(request.headers.get("mcp-server-base-url"))
|
||||
_convert_localhost_to_docker_host(
|
||||
request.headers.get(MCP_SERVER_BASE_URL_HEADER)
|
||||
)
|
||||
+ "/messages/?"
|
||||
+ session_id
|
||||
)
|
||||
@@ -103,14 +106,12 @@ async def mcp_post_gateway(
|
||||
elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
# Intercept and potentially block the request
|
||||
hook_tool_call_result, is_blocked = await _hook_tool_call(
|
||||
session_id=session_id, request_json={
|
||||
session_id=session_id,
|
||||
request_json={
|
||||
"id": request_json.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {
|
||||
"name": MCP_LIST_TOOLS,
|
||||
"arguments": {}
|
||||
},
|
||||
}
|
||||
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
|
||||
},
|
||||
)
|
||||
if is_blocked:
|
||||
# Add the error message to the session.
|
||||
@@ -147,18 +148,16 @@ async def mcp_post_gateway(
|
||||
raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
|
||||
|
||||
@gateway.get("/mcp/sse")
|
||||
async def mcp_get_sse_gateway(
|
||||
request: Request,
|
||||
) -> StreamingResponse:
|
||||
"""Proxy calls to the MCP Server tools"""
|
||||
mcp_server_base_url = request.headers.get("mcp-server-base-url")
|
||||
mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER)
|
||||
if not mcp_server_base_url:
|
||||
print("missing base url", request.headers, flush=True)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Missing 'mcp-server-base-url' header",
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
mcp_server_sse_endpoint = (
|
||||
_convert_localhost_to_docker_host(mcp_server_base_url) + "/sse"
|
||||
@@ -243,7 +242,7 @@ async def mcp_get_sse_gateway(
|
||||
# Pass through other event types
|
||||
# pylint: disable=line-too-long
|
||||
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(
|
||||
"utf-8"
|
||||
UTF_8
|
||||
)
|
||||
|
||||
# Put the processed event in the queue
|
||||
@@ -251,7 +250,7 @@ async def mcp_get_sse_gateway(
|
||||
|
||||
except httpx.StreamClosed as e:
|
||||
print(f"Server stream closed: {e}", flush=True)
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Error processing server events: {e}", flush=True)
|
||||
|
||||
# Start server events processor
|
||||
@@ -360,7 +359,9 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo
|
||||
return request_json, False
|
||||
|
||||
|
||||
async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict:
|
||||
async def _hook_tool_call_response(
|
||||
session_id: str, response_json: dict, is_tools_list=False
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
@@ -404,10 +405,10 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool
|
||||
else:
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "blocked_" + tool["name"],
|
||||
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
|
||||
@@ -425,7 +426,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
@@ -489,7 +490,7 @@ async def _handle_endpoint_event(
|
||||
"/messages/?session_id=",
|
||||
"/api/v1/gateway/mcp/sse/messages/?session_id=",
|
||||
)
|
||||
event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8")
|
||||
event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode(UTF_8)
|
||||
return event_bytes, session_id
|
||||
|
||||
|
||||
@@ -501,7 +502,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
session_id (str): The session ID associated with the request.
|
||||
sse (ServerSentEvent): The original SSE object.
|
||||
"""
|
||||
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8")
|
||||
event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8)
|
||||
session = session_store.get_session(session_id)
|
||||
try:
|
||||
response_json = json.loads(sse.data)
|
||||
@@ -525,7 +526,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
UTF_8
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
@@ -538,9 +539,11 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
response_json={
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"content": json.dumps(response_json.get(MCP_RESULT).get("tools")),
|
||||
"content": json.dumps(
|
||||
response_json.get(MCP_RESULT).get("tools")
|
||||
),
|
||||
"tools": response_json.get(MCP_RESULT).get("tools"),
|
||||
}
|
||||
},
|
||||
},
|
||||
is_tools_list=True,
|
||||
)
|
||||
@@ -550,7 +553,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
"utf-8"
|
||||
UTF_8
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -591,7 +594,7 @@ async def _check_for_pending_error_messages(
|
||||
for error_message in error_messages:
|
||||
error_bytes = (
|
||||
f"event: message\ndata: {json.dumps(error_message)}\n\n".encode(
|
||||
"utf-8"
|
||||
UTF_8
|
||||
)
|
||||
)
|
||||
await pending_error_messages_queue.put(error_bytes)
|
||||
|
||||
@@ -45,7 +45,7 @@ async def test_mcp_with_gateway(
|
||||
project_name,
|
||||
push_to_explorer=push_to_explorer,
|
||||
tool_name="get_last_message_from_user",
|
||||
tool_args={"username": "Alice"}
|
||||
tool_args={"username": "Alice"},
|
||||
)
|
||||
else:
|
||||
result = await mcp_stdio_client_run(
|
||||
@@ -307,7 +307,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
async def test_mcp_sse_with_gateway_hybrid_guardrails(
|
||||
async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport
|
||||
):
|
||||
"""Test MCP gateway and verify that logging and blocking guardrails work together"""
|
||||
@@ -425,7 +425,6 @@ async def test_mcp_sse_with_gateway_hybrid_guardrails(
|
||||
assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("transport", ["stdio", "sse"])
|
||||
@@ -434,7 +433,7 @@ async def test_mcp_tool_list_blocking(
|
||||
):
|
||||
"""
|
||||
Tests that blocking guardrails work for the tools/list call.
|
||||
|
||||
|
||||
For those, the expected behavior is that the returned tools are all renamed to blocked_... and include an informative block notice, instead of the original tool description.
|
||||
"""
|
||||
project_name = "test-mcp-" + str(uuid.uuid4())
|
||||
@@ -473,5 +472,7 @@ async def test_mcp_tool_list_blocking(
|
||||
tool_args={},
|
||||
)
|
||||
|
||||
assert "blocked_get_last_message_from_user" in str(tools_result), "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + str(tools_result)
|
||||
|
||||
assert "blocked_get_last_message_from_user" in str(tools_result), (
|
||||
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
|
||||
+ str(tools_result)
|
||||
)
|
||||
Reference in New Issue
Block a user