diff --git a/gateway/common/constants.py b/gateway/common/constants.py index 077d2be..715c42a 100644 --- a/gateway/common/constants.py +++ b/gateway/common/constants.py @@ -15,6 +15,7 @@ IGNORED_HEADERS = [ CLIENT_TIMEOUT = 60.0 # MCP related constants +UTF_8 = "utf-8" MCP_METHOD = "method" MCP_TOOL_CALL = "tools/call" MCP_LIST_TOOLS = "tools/list" diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py index 359f3c4..3cf2ef0 100644 --- a/gateway/common/mcp_sessions_manager.py +++ b/gateway/common/mcp_sessions_manager.py @@ -65,23 +65,8 @@ class McpSession(BaseModel): """Deduplicate new_annotations using the annotations in the session.""" deduped_annotations = [] for annotation in new_annotations: - # Check if an annotation with the same content and address exists in self.annotations - # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in - # to remove duplicates instead of using a custom logic. - # This is a temporary solution until the Invariant SDK is updated. - is_duplicate = False - for current_annotation in self.annotations: - if ( - annotation.content == current_annotation.content - and annotation.address == current_annotation.address - and annotation.extra_metadata == current_annotation.extra_metadata - ): - is_duplicate = True - break - - if not is_duplicate: + if annotation not in self.annotations: deduped_annotations.append(annotation) - return deduped_annotations @contextlib.asynccontextmanager diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 4ab4a0c..e2b2a77 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -66,9 +66,6 @@ def create_annotations_from_guardrails_errors( ) ) # Remove duplicates - # TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class - # to remove duplicates instead of using a custom function. - # This is a temporary solution until the Invariant SDK is updated. return remove_duplicates(annotations) @@ -85,7 +82,7 @@ def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCre for annotation in annotations: # Convert the entire extra_metadata dict to a JSON string # This creates a hashable representation regardless of nested content - metadata_str = json.dumps(annotation.extra_metadata, sort_keys=True) + metadata_str = json.dumps(annotation.extra_metadata or {}, sort_keys=True) # Create a unique identifier using all three fields unique_key = (annotation.content, annotation.address, metadata_str) diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py index 4c55a8c..865d11c 100644 --- a/gateway/mcp/mcp.py +++ b/gateway/mcp/mcp.py @@ -1,12 +1,14 @@ """Gateway for MCP (Model Context Protocol) integration with Invariant.""" -import sys -import subprocess +import asyncio +import getpass import json import os -import select -import asyncio import platform +import select +import socket +import subprocess +import sys from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.append_messages import AppendMessagesRequest @@ -21,6 +23,7 @@ from gateway.common.constants import ( MCP_SERVER_INFO, MCP_TOOL_CALL, MCP_LIST_TOOLS, + UTF_8, ) from gateway.common.guardrails import GuardrailAction from gateway.common.request_context import RequestContext @@ -28,15 +31,17 @@ from gateway.integrations.explorer import create_annotations_from_guardrails_err from gateway.integrations.guardrails import check_guardrails from gateway.mcp.log import mcp_log, MCP_LOG_FILE, format_errors_in_response from gateway.mcp.mcp_context import McpContext -from gateway.mcp.task_utils import run_task_in_background, run_task_sync -import getpass -import socket +from gateway.mcp.task_utils import run_task_sync + -UTF_8_ENCODING = "utf-8" DEFAULT_API_URL = "https://explorer.invariantlabs.ai" +STATUS_EOF = "eof" +STATUS_DATA = "data" +STATUS_WAIT = "wait" def user_and_host() -> str: + """Get the current user and hostname.""" username = getpass.getuser() hostname = socket.gethostname() @@ -44,6 +49,7 @@ def user_and_host() -> str: def session_metadata(ctx: McpContext) -> dict: + """Generate metadata for the current session.""" return { "session_id": ctx.local_session_id, "system_user": user_and_host(), @@ -56,30 +62,15 @@ def session_metadata(ctx: McpContext) -> dict: def write_as_utf8_bytes(data: dict) -> bytes: """Serializes dict to bytes using UTF-8 encoding.""" - return json.dumps(data).encode(UTF_8_ENCODING) + b"\n" + return json.dumps(data).encode(UTF_8) + b"\n" def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list: """Deduplicate new_annotations using the annotations in the context.""" deduped_annotations = [] for annotation in new_annotations: - # Check if an annotation with the same content and address exists in ctx.annotations - # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in - # to remove duplicates instead of using a custom logic. - # This is a temporary solution until the Invariant SDK is updated. - is_duplicate = False - for ctx_annotation in ctx.annotations: - if ( - annotation.content == ctx_annotation.content - and annotation.address == ctx_annotation.address - and annotation.extra_metadata == ctx_annotation.extra_metadata - ): - is_duplicate = True - break - - if not is_duplicate: + if annotation not in ctx.annotations: deduped_annotations.append(annotation) - return deduped_annotations @@ -94,43 +85,6 @@ def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool: return False -async def get_guardrails_check_result( - ctx: McpContext, - message: dict, - action: GuardrailAction = GuardrailAction.BLOCK, -) -> dict: - """ - Check against guardrails of type action in an async manner. - """ - # Skip if no guardrails are configured for this action - if not ( - (ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK) - or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG) - ): - return {} - - # Prepare context and select appropriate guardrails - context = RequestContext.create( - request_json={}, - dataset_name=ctx.explorer_dataset, - invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), - guardrails=ctx.guardrails, - ) - - guardrails_to_check = ( - ctx.guardrails.blocking_guardrails - if action == GuardrailAction.BLOCK - else ctx.guardrails.logging_guardrails - ) - - # Run check_guardrails asynchronously - return await check_guardrails( - messages=ctx.trace + [message], - guardrails=guardrails_to_check, - context=context, - ) - - async def append_and_push_trace( ctx: McpContext, message: dict, guardrails_result: dict ) -> None: @@ -195,7 +149,7 @@ async def append_and_push_trace( ) ctx.last_trace_length = len(ctx.trace) ctx.annotations.extend(deduplicated_annotations) - except Exception as e: + except Exception as e: # pylint: disable=broad-except mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e) @@ -331,17 +285,17 @@ async def hook_tool_result(ctx: McpContext, result: dict) -> dict: # Safely handle result object result_obj = result.get("result", {}) - if isinstance(result_obj, dict) and MCP_SERVER_INFO in result_obj: + if result_obj.get(MCP_SERVER_INFO): ctx.mcp_server_name = result_obj.get(MCP_SERVER_INFO, {}).get("name", "") - if method is None: + if not method: return result elif method == MCP_TOOL_CALL: message = { "role": "tool", "content": result_obj.get("content"), "error": result_obj.get("error"), - "tool_call_id": call_id + "tool_call_id": call_id, } # Check for blocking guardrails guardrailing_result = await get_guardrails_check_result( @@ -417,7 +371,7 @@ async def stream_and_forward_stdout( try: # Process complete JSON lines - line_str = line.decode(UTF_8_ENCODING).strip() + line_str = line.decode(UTF_8).strip() if not line_str: continue @@ -431,9 +385,7 @@ async def stream_and_forward_stdout( sys.stdout.buffer.write(write_as_utf8_bytes(processed_json)) sys.stdout.buffer.flush() - except Exception as e: - import traceback - mcp_log(traceback.format_exc()) + except Exception as e: # pylint: disable=broad-except mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}") if line: mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...") @@ -458,12 +410,13 @@ async def stream_and_forward_stderr( async def process_line( ctx: McpContext, mcp_process: subprocess.Popen, line: bytes ) -> None: + """Process a line of input from stdin, decode it, and forward to mcp_process.""" if ctx.verbose: mcp_log(f"[INFO] client -> server: {line}") # Try to decode and parse as JSON to check for tool calls try: - text = line.decode(UTF_8_ENCODING) + text = line.decode(UTF_8) parsed_json = json.loads(text) except json.JSONDecodeError as je: mcp_log(f"[ERROR] JSON decode error in run_stdio_input_loop: {str(je)}") @@ -471,12 +424,10 @@ async def process_line( return if parsed_json.get(MCP_METHOD) is not None: - ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get( - MCP_METHOD - ) - if "params" in parsed_json and "clientInfo" in parsed_json.get("params"): + ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(MCP_METHOD) + if parsed_json.get(MCP_PARAMS) and parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO): ctx.mcp_client_name = ( - parsed_json.get("params").get("clientInfo").get("name", "") + parsed_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") ) # Check if this is a tool call request @@ -506,8 +457,6 @@ async def process_line( if parsed_json.get(MCP_METHOD) == MCP_LIST_TOOLS: # Refresh guardrails run_task_sync(ctx.load_guardrails) - - # mcp_message_{} ctx.trace.append( { "role": "assistant", @@ -528,14 +477,16 @@ async def process_line( mcp_process.stdin.flush() -async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) -> tuple[bytes | None, str]: +async def wait_for_stdin_input( + loop: asyncio.AbstractEventLoop, stdin_fd: int +) -> tuple[bytes | None, str]: """ Platform-specific implementation to wait for and read input from stdin. - + Args: loop: The asyncio event loop stdin_fd: The file descriptor for stdin - + Returns: tuple[bytes | None, str]: A tuple containing: - The data read from stdin or None @@ -548,11 +499,11 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) - try: chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) if not chunk: # Empty bytes means EOF - return None, 'eof' - return chunk, 'data' + return None, STATUS_EOF + return chunk, STATUS_DATA except (BlockingIOError, OSError): # No data available yet - return None, 'wait' + return None, STATUS_WAIT else: # On Unix-like systems, use select ready, _, _ = await loop.run_in_executor( @@ -562,13 +513,13 @@ async def wait_for_stdin_input(loop: asyncio.AbstractEventLoop, stdin_fd: int) - if not ready: # No input available, yield to other tasks await asyncio.sleep(0.01) - return None, 'wait' + return None, STATUS_WAIT # Read available data chunk = await loop.run_in_executor(None, lambda: os.read(stdin_fd, 4096)) if not chunk: # Empty bytes means EOF - return None, 'eof' - return chunk, 'data' + return None, STATUS_EOF + return chunk, STATUS_DATA async def run_stdio_input_loop( @@ -589,14 +540,14 @@ async def run_stdio_input_loop( while True: # Get input using platform-specific method chunk, status = await wait_for_stdin_input(loop, stdin_fd) - - if status == 'eof': + + if status == STATUS_EOF: # EOF detected, break the loop break - elif status == 'wait': + elif status == STATUS_WAIT: # No data available yet, continue polling continue - elif status == 'data': + elif status == STATUS_DATA: # We got some data, process it buffer += chunk diff --git a/gateway/mcp/task_utils.py b/gateway/mcp/task_utils.py index b507d0c..2ff8b87 100644 --- a/gateway/mcp/task_utils.py +++ b/gateway/mcp/task_utils.py @@ -2,41 +2,11 @@ import asyncio import concurrent.futures -import threading from contextlib import redirect_stdout from typing import Any -from gateway.mcp.log import MCP_LOG_FILE, mcp_log - - -def run_task_in_background(async_func, *args, **kwargs): - """ - Runs an async function in a background thread with its own event loop. - This function does NOT block the calling thread as it immediately returns - after starting the background thread. - - Args: - async_func: The async function to run - *args: Positional arguments to pass to the async function - **kwargs: Keyword arguments to pass to the async function - """ - - def thread_target(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(async_func(*args, **kwargs)) - except Exception as e: - mcp_log( - f"[ERROR] Error in async thread while running run_task_in_background: {e}" - ) - finally: - loop.close() - - # Create and start a daemon thread - thread = threading.Thread(target=thread_target, daemon=True) - thread.start() +from gateway.mcp.log import MCP_LOG_FILE def run_task_sync(async_func, *args, **kwargs) -> Any: diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index 163d7e8..12e56fd 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -22,6 +22,7 @@ from gateway.common.constants import ( MCP_RESULT, MCP_SERVER_INFO, MCP_CLIENT_INFO, + UTF_8, ) from gateway.common.guardrails import GuardrailAction from gateway.common.mcp_sessions_manager import ( @@ -42,12 +43,12 @@ MCP_SERVER_SSE_HEADERS = { "accept", "cache-control", } +MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url" gateway = APIRouter() session_store = McpSessionsManager() -@gateway.post("/mcp/messages/") @gateway.post("/mcp/sse/messages/") async def mcp_post_gateway( request: Request, @@ -64,15 +65,17 @@ async def mcp_post_gateway( status_code=400, detail="Session does not exist", ) - if not request.headers.get("mcp-server-base-url"): + if not request.headers.get(MCP_SERVER_BASE_URL_HEADER): return HTTPException( status_code=400, - detail="Missing 'mcp-server-base-url' header", + detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) session_id = query_params.get("session_id") mcp_server_messages_endpoint = ( - _convert_localhost_to_docker_host(request.headers.get("mcp-server-base-url")) + _convert_localhost_to_docker_host( + request.headers.get(MCP_SERVER_BASE_URL_HEADER) + ) + "/messages/?" + session_id ) @@ -103,14 +106,12 @@ async def mcp_post_gateway( elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS: # Intercept and potentially block the request hook_tool_call_result, is_blocked = await _hook_tool_call( - session_id=session_id, request_json={ + session_id=session_id, + request_json={ "id": request_json.get("id"), "method": MCP_LIST_TOOLS, - "params": { - "name": MCP_LIST_TOOLS, - "arguments": {} - }, - } + "params": {"name": MCP_LIST_TOOLS, "arguments": {}}, + }, ) if is_blocked: # Add the error message to the session. @@ -147,18 +148,16 @@ async def mcp_post_gateway( raise HTTPException(status_code=500, detail="Unexpected error") from e - @gateway.get("/mcp/sse") async def mcp_get_sse_gateway( request: Request, ) -> StreamingResponse: """Proxy calls to the MCP Server tools""" - mcp_server_base_url = request.headers.get("mcp-server-base-url") + mcp_server_base_url = request.headers.get(MCP_SERVER_BASE_URL_HEADER) if not mcp_server_base_url: - print("missing base url", request.headers, flush=True) raise HTTPException( status_code=400, - detail="Missing 'mcp-server-base-url' header", + detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) mcp_server_sse_endpoint = ( _convert_localhost_to_docker_host(mcp_server_base_url) + "/sse" @@ -243,7 +242,7 @@ async def mcp_get_sse_gateway( # Pass through other event types # pylint: disable=line-too-long event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode( - "utf-8" + UTF_8 ) # Put the processed event in the queue @@ -251,7 +250,7 @@ async def mcp_get_sse_gateway( except httpx.StreamClosed as e: print(f"Server stream closed: {e}", flush=True) - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Error processing server events: {e}", flush=True) # Start server events processor @@ -360,7 +359,9 @@ async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bo return request_json, False -async def _hook_tool_call_response(session_id: str, response_json: dict, is_tools_list=False) -> dict: +async def _hook_tool_call_response( + session_id: str, response_json: dict, is_tools_list=False +) -> dict: """ Hook to process the response JSON after receiving it from the MCP server. @@ -404,10 +405,10 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool else: # special error response for tools/list tool call result = { - "jsonrpc": "2.0", - "id": response_json.get("id"), - "result": { - "tools": [ + "jsonrpc": "2.0", + "id": response_json.get("id"), + "result": { + "tools": [ { "name": "blocked_" + tool["name"], "description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE @@ -425,7 +426,7 @@ async def _hook_tool_call_response(session_id: str, response_json: dict, is_tool } for tool in response_json["result"]["tools"] ] - } + }, } # Push trace to the explorer - don't block on its response @@ -489,7 +490,7 @@ async def _handle_endpoint_event( "/messages/?session_id=", "/api/v1/gateway/mcp/sse/messages/?session_id=", ) - event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode("utf-8") + event_bytes = f"event: {sse.event}\ndata: {modified_data}\n\n".encode(UTF_8) return event_bytes, session_id @@ -501,7 +502,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: session_id (str): The session ID associated with the request. sse (ServerSentEvent): The original SSE object. """ - event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode("utf-8") + event_bytes = f"event: {sse.event}\ndata: {sse.data}\n\n".encode(UTF_8) session = session_store.get_session(session_id) try: response_json = json.loads(sse.data) @@ -525,7 +526,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: # pylint: disable=line-too-long if blocked: event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode( - "utf-8" + UTF_8 ) elif method == MCP_LIST_TOOLS: # store tools in metadata @@ -538,9 +539,11 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: response_json={ "id": response_json.get("id"), "result": { - "content": json.dumps(response_json.get(MCP_RESULT).get("tools")), + "content": json.dumps( + response_json.get(MCP_RESULT).get("tools") + ), "tools": response_json.get(MCP_RESULT).get("tools"), - } + }, }, is_tools_list=True, ) @@ -550,7 +553,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: # pylint: disable=line-too-long if blocked: event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode( - "utf-8" + UTF_8 ) except json.JSONDecodeError as e: @@ -591,7 +594,7 @@ async def _check_for_pending_error_messages( for error_message in error_messages: error_bytes = ( f"event: message\ndata: {json.dumps(error_message)}\n\n".encode( - "utf-8" + UTF_8 ) ) await pending_error_messages_queue.put(error_bytes) diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index 13d3dc4..758ac38 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -45,7 +45,7 @@ async def test_mcp_with_gateway( project_name, push_to_explorer=push_to_explorer, tool_name="get_last_message_from_user", - tool_args={"username": "Alice"} + tool_args={"username": "Alice"}, ) else: result = await mcp_stdio_client_run( @@ -307,7 +307,7 @@ async def test_mcp_with_gateway_and_blocking_guardrails( @pytest.mark.asyncio @pytest.mark.timeout(30) @pytest.mark.parametrize("transport", ["stdio", "sse"]) -async def test_mcp_sse_with_gateway_hybrid_guardrails( +async def test_mcp_with_gateway_hybrid_guardrails( explorer_api_url, invariant_gateway_package_whl_file, gateway_url, transport ): """Test MCP gateway and verify that logging and blocking guardrails work together""" @@ -425,7 +425,6 @@ async def test_mcp_sse_with_gateway_hybrid_guardrails( assert tool_call_annotation["extra_metadata"]["guardrail"]["action"] == "log" - @pytest.mark.asyncio @pytest.mark.timeout(30) @pytest.mark.parametrize("transport", ["stdio", "sse"]) @@ -434,7 +433,7 @@ async def test_mcp_tool_list_blocking( ): """ Tests that blocking guardrails work for the tools/list call. - + For those, the expected behavior is that the returned tools are all renamed to blocked_... and include an informative block notice, instead of the original tool description. """ project_name = "test-mcp-" + str(uuid.uuid4()) @@ -473,5 +472,7 @@ async def test_mcp_tool_list_blocking( tool_args={}, ) - assert "blocked_get_last_message_from_user" in str(tools_result), "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + str(tools_result) - + assert "blocked_get_last_message_from_user" in str(tools_result), ( + "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + + str(tools_result) + ) \ No newline at end of file