diff --git a/gateway/__main__.py b/gateway/__main__.py index b8c3961..ee4c7b9 100644 --- a/gateway/__main__.py +++ b/gateway/__main__.py @@ -1,34 +1,52 @@ """Script is used to run actions using the Invariant Gateway.""" +import asyncio +import signal import sys from gateway.mcp import mcp -def main(): - """Entry point for the Invariant Gateway.""" +# Handle signals to ensure clean shutdown +def signal_handler(sig, frame): + """Handle signals for graceful shutdown.""" + sys.exit(0) + + +def print_help(): + """Prints the help message.""" actions = { - "mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features", - "llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features", - "help": "Shows this help message", + "mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features.", + "llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features. 
Not implemented yet.", + "help": "Shows this help message.", } - def _help(): - """_prints the help message.""" - for verb, description in actions.items(): - print(f"{verb}: {description}") + for verb, description in actions.items(): + print(f"{verb}: {description}") + + +def main(): + """Entry point for the Invariant Gateway.""" + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) if len(sys.argv) < 2: - _help() + print_help() sys.exit(1) verb = sys.argv[1] if verb == "mcp": - return mcp.execute(sys.argv[2:]) + print("[MCP] Running Invariant Gateway against MCP servers...") + return asyncio.run(mcp.execute(sys.argv[2:])) if verb == "llm": + print("[gateway/__main__.py] 'llm' action is not implemented yet.") return 1 if verb == "help": - _help() + print_help() return 0 print(f"[gateway/__main__.py] Unknown action: {verb}") return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 9499ac8..98d83e5 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -93,7 +93,9 @@ class GatewayConfigManager: return local_config -async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]: +async def extract_guardrails_from_header( + request: fastapi.Request, +) -> Optional[GuardrailRuleSet]: """ Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding GuardrailRuleSet. If no guardrails are provided, returns None. 
diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 85f820d..dfd8bee 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -1,8 +1,9 @@ """Utility functions for the Invariant explorer.""" import os -from typing import Any, Dict, List +import json +from typing import Any, Dict, List from fastapi import HTTPException from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction @@ -16,17 +17,20 @@ DEFAULT_API_URL = "https://explorer.invariantlabs.ai" def create_annotations_from_guardrails_errors( - guardrails_errors: List[dict], action: str = "block" + guardrails_errors: List[dict], ) -> List[AnnotationCreate]: """Create Explorer annotations from the guardrails errors.""" annotations = [] - def _remove_prefixes(ranges: list[str]) -> list[str]: + def _pick_most_specific_ranges(ranges: list[str]) -> list[str]: """ - Remove prefixes from the list of ranges. + Remove redundant prefixes from the list of ranges. If the ranges are ['messages.2', 'messages.2.content:25-30', 'messages.2.content'] then this returns ['messages.2.content:25-30']. + + This picks the most specific subset of the ranges and removes the rest. If some + range is a proper prefix of another range, it is removed. """ ranges = sorted(ranges, key=len) result = [] @@ -44,7 +48,7 @@ def create_annotations_from_guardrails_errors( for error in guardrails_errors: content = error.get("args")[0] - filtered_ranges = _remove_prefixes(list(error.get("ranges", []))) + filtered_ranges = _pick_most_specific_ranges(list(error.get("ranges", []))) for r in filtered_ranges: annotations.append( AnnotationCreate( @@ -61,10 +65,40 @@ def create_annotations_from_guardrails_errors( }, ) ) - return annotations + # Remove duplicates + # TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class + # to remove duplicates instead of using a custom function. + # This is a temporary solution until the Invariant SDK is updated. 
+ return remove_duplicates(annotations) + + +def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]: + """ + Remove duplicate annotations based on content, address, and extra_metadata. + + Two annotations are considered duplicates if they have the same content, + address, and extra_metadata. + """ + unique_annotations = [] + seen = set() + + for annotation in annotations: + # Convert the entire extra_metadata dict to a JSON string + # This creates a hashable representation regardless of nested content + metadata_str = json.dumps(annotation.extra_metadata, sort_keys=True) + + # Create a unique identifier using all three fields + unique_key = (annotation.content, annotation.address, metadata_str) + + if unique_key not in seen: + seen.add(unique_key) + unique_annotations.append(annotation) + + return unique_annotations def get_explorer_api_url() -> str: + """Get the Invariant Explorer API URL from the environment variable.""" return os.getenv("INVARIANT_API_URL", DEFAULT_API_URL) diff --git a/gateway/mcp/README.md b/gateway/mcp/README.md new file mode 100644 index 0000000..d272e6b --- /dev/null +++ b/gateway/mcp/README.md @@ -0,0 +1,63 @@ +This is a work-in-progress implementation of MCP (Model Context Protocol) support in the Gateway. + +This repository will be published to PyPI, after which using it with the MCP config file will be simpler. + +For now, if the original MCP config file looks like: + +``` +{ + "mcpServers": { + "weather": { + "command": "uv", + "args": [ + "--directory", + "/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather", + "run", + "weather.py" + ] + } + } +} +``` + +You need to: + +1. Check out the invariant-gateway repo. +2. Run `python -m build`. This will generate a .whl file in dist. +3. 
Modify the MCP config like this: + +``` + { + "mcpServers": { + "weather": { + "command": "uvx", + "args": [ + "--refresh", + "--from", + "/ABSOLUTE/PATH/TO/INVARIANT_GATEWAY_REPO/dist/invariant_gateway-0.0.1-py3-none-any.whl", + "invariant-gateway", + "mcp", + "--dataset-name", + "weather-testing", + "--push-explorer", + "--exec", + "uv", + "--directory", + "/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather", + "run", + "weather.py" + ], + "env": { + "INVARIANT_API_KEY": "" + } + } + } + } +``` + +This moves the original `command` and `args` to the `args` list after the `--exec` flag. + +All args before the `--exec` flag are relevant to the Invariant MCP gateway. These include: + +- `--dataset-name`: With this you can specify the name of the dataset. The guardrails are pulled from this dataset. +- `--push-explorer`: With this you can specify whether you want to push the annotated traces to the Invariant Explorer. diff --git a/gateway/mcp/log.py b/gateway/mcp/log.py new file mode 100644 index 0000000..667119d --- /dev/null +++ b/gateway/mcp/log.py @@ -0,0 +1,19 @@ +"""Custom log configuration.""" + +import os +import sys + +from builtins import print as builtins_print + +os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True) +MCP_LOG_FILE = open( + os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"), + "a", + buffering=1, +) +sys.stderr = MCP_LOG_FILE + + +def mcp_log(*args, **kwargs) -> None: + """Custom print function that writes output to MCP_LOG_FILE.""" + builtins_print(*args, **kwargs, file=MCP_LOG_FILE, flush=True) diff --git a/gateway/mcp/mcp.py b/gateway/mcp/mcp.py index 015d8fd..6d268ef 100644 --- a/gateway/mcp/mcp.py +++ b/gateway/mcp/mcp.py @@ -1,99 +1,181 @@ """Gateway for MCP (Model Context Protocol) integration with Invariant.""" -import argparse -import asyncio import sys import subprocess import json import os import threading -import signal -from builtins import print as builtins_print -from contextlib import redirect_stdout +from 
invariant_sdk.async_client import AsyncClient +from invariant_sdk.types.append_messages import AppendMessagesRequest +from invariant_sdk.types.push_traces import PushTracesRequest +from gateway.common.guardrails import GuardrailAction from gateway.common.request_context import RequestContext +from gateway.integrations.explorer import create_annotations_from_guardrails_errors from gateway.integrations.guardrails import check_guardrails -from gateway.integrations.explorer import ( - fetch_guardrails_from_explorer, -) +from gateway.mcp.log import mcp_log, MCP_LOG_FILE from gateway.mcp.mcp_context import McpContext +from gateway.mcp.task_utils import run_task_in_background, run_task_sync + +MCP_METHOD = "method" +UTF_8_ENCODING = "utf-8" +MCP_TOOL_CALL = "tools/call" +MCP_LIST_TOOLS = "tools/list" +INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """ + [Invariant Guardrails] The MCP tool call was blocked for security reasons. + Do not attempt to circumvent this block, rather explain to the user based + on the following output what went wrong: %s + """ -def custom_print(ctx, *args, **kwargs): - """Custom print function to redirect output to log_out.""" - builtins_print(*args, **kwargs, file=ctx.log_out, flush=True) +def write_as_utf8_bytes(data: dict) -> bytes: + """Serializes dict to bytes using UTF-8 encoding.""" + return json.dumps(data).encode(UTF_8_ENCODING) + b"\n" -def append_and_push_trace(ctx, message): +def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list: + """Deduplicate new_annotations using the annotations in the context.""" + deduped_annotations = [] + for annotation in new_annotations: + # Check if an annotation with the same content and address exists in ctx.annotations + # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in + # to remove duplicates instead of using a custom logic. + # This is a temporary solution until the Invariant SDK is updated. 
+ is_duplicate = False + for ctx_annotation in ctx.annotations: + if ( + annotation.content == ctx_annotation.content + and annotation.address == ctx_annotation.address + and annotation.extra_metadata == ctx_annotation.extra_metadata + ): + is_duplicate = True + break + + if not is_duplicate: + deduped_annotations.append(annotation) + + return deduped_annotations + + +def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool: + """Checks if there are new errors in the guardrails result.""" + annotations = create_annotations_from_guardrails_errors( + guardrails_result.get("errors", []) + ) + for annotation in annotations: + if annotation not in ctx.annotations: + return True + return False + + +async def append_and_push_trace( + ctx: McpContext, message: dict, guardrails_result: dict +) -> None: """ Append a message to the trace if it exists or create a new one and push it to the Invariant Explorer. + + This function runs asynchronously in the background. """ + + annotations = [] + if guardrails_result and guardrails_result.get("errors", []): + annotations = create_annotations_from_guardrails_errors( + guardrails_result["errors"] + ) + + if ctx.guardrails.logging_guardrails: + logging_guardrails_check_result = get_guardrails_check_result( + ctx, message, action=GuardrailAction.LOG + ) + if logging_guardrails_check_result and logging_guardrails_check_result.get( + "errors", [] + ): + annotations.extend( + create_annotations_from_guardrails_errors( + logging_guardrails_check_result["errors"] + ) + ) + deduplicated_annotations = deduplicate_annotations(ctx, annotations) + try: + # If the trace_id is None, create a new trace with the messages. + # Otherwise, append the message to the existing trace. 
+ client = AsyncClient() if ctx.trace_id is None: ctx.trace.append(message) - response = ctx.client.create_request_and_push_trace( - messages=[ctx.trace], - dataset=ctx.explorer_dataset, - metadata=[{"source": "mcp", "tools": ctx.tools}], + response = await client.push_trace( + PushTracesRequest( + messages=[ctx.trace], + dataset=ctx.explorer_dataset, + metadata=[{"source": "mcp", "tools": ctx.tools}], + annotations=[deduplicated_annotations], + ) ) ctx.trace_id = response.id[0] ctx.last_trace_length = len(ctx.trace) + ctx.annotations.extend(deduplicated_annotations) else: ctx.trace.append(message) - ctx.client.create_request_and_append_messages( - trace_id=ctx.trace_id, messages=ctx.trace[ctx.last_trace_length :] + response = await client.append_messages( + AppendMessagesRequest( + trace_id=ctx.trace_id, + messages=ctx.trace[ctx.last_trace_length :], + annotations=deduplicated_annotations, + ) ) ctx.last_trace_length = len(ctx.trace) + ctx.annotations.extend(deduplicated_annotations) except Exception as e: - custom_print(ctx, "Error pushing trace:", e) + mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e) -def fetch_guardrails(ctx, dataset): - """Fetch guardrails from the Invariant Explorer.""" - # Use async fetch_guardrails_from_explorer in a thread - return asyncio.run( - fetch_guardrails_from_explorer( - dataset, "Bearer " + os.getenv("INVARIANT_API_KEY") - ) +def get_guardrails_check_result( + ctx: McpContext, + message: dict, + action: GuardrailAction = GuardrailAction.BLOCK, +) -> dict: + """ + Check against guardrails of type action. + Works in both sync and async contexts by always using a dedicated thread. 
+ """ + # Skip if no guardrails are configured for this action + if not ( + (ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK) + or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG) + ): + return {} + + # Prepare context and select appropriate guardrails + context = RequestContext.create( + request_json={}, + dataset_name=ctx.explorer_dataset, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + guardrails=ctx.guardrails, + ) + + guardrails_to_check = ( + ctx.guardrails.blocking_guardrails + if action == GuardrailAction.BLOCK + else ctx.guardrails.logging_guardrails + ) + + return run_task_sync( + check_guardrails, + messages=ctx.trace + [message], + guardrails=guardrails_to_check, + context=context, ) -def check_blocking_guardrails(ctx, message, request): - """Check against blocking guardrails.""" - try: - guardrails = fetch_guardrails(ctx, ctx.explorer_dataset) - - custom_print(ctx, "Here are the guardrails: ", guardrails) - - context = RequestContext.create( - request_json=request, - dataset_name=ctx.explorer_dataset, - invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), - guardrails=guardrails, - ) - - if guardrails.blocking_guardrails: - with redirect_stdout(ctx.log_out): - return asyncio.run( - check_guardrails( - messages=ctx.trace + [message], - guardrails=guardrails.blocking_guardrails, - context=context, - ) - ) - else: - return {} - except Exception as e: - custom_print(ctx, "Error checking blocking guardrails:", e) - - -def hook_tool_call(ctx, request): +def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]: """ Hook function to intercept tool calls. - Modify this function to change behavior for tool calls. - Returns the potentially modified request. + + If the request is blocked, it returns a tuple with a message explaining the block + and a flag indicating the request was blocked. 
+ Otherwise it returns the original request and a flag indicating it was not blocked. """ tool_call = { "id": f"call_{request.get('id')}", @@ -106,13 +188,39 @@ def hook_tool_call(ctx, request): message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} - # Check for blocking guardrails - result = check_blocking_guardrails(ctx, message, request) - append_and_push_trace(ctx, message) - return request + # Check for blocking guardrails - this blocks until completion + guardrailing_result = get_guardrails_check_result( + ctx, message, action=GuardrailAction.BLOCK + ) + + # If the request is blocked, return a message indicating the block reason. + # If there are new errors, run append_and_push_trace in background. + # If there are no new errors, just return the original request. + if ( + guardrailing_result + and guardrailing_result.get("errors", []) + and check_if_new_errors(ctx, guardrailing_result) + ): + if ctx.push_explorer: + run_task_in_background( + append_and_push_trace, ctx, message, guardrailing_result + ) + return { + "jsonrpc": "2.0", + "id": request.get("id"), + "error": { + "code": -32600, + "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE + % guardrailing_result["errors"], + }, + }, True + + # Add the message to the trace + ctx.trace.append(message) + return request, False -def hook_tool_result(ctx, result): +def hook_tool_result(ctx: McpContext, result: dict) -> dict: """ Hook function to intercept tool results. Modify this function to change behavior for tool results. 
@@ -123,153 +231,126 @@ def hook_tool_result(ctx, result): if method is None: return result - elif method == "tools/call": + elif method == MCP_TOOL_CALL: message = { "role": "tool", - "content": {"result": result.get("result").get("content")}, + "content": result.get("result").get("content"), "error": result.get("result").get("error"), "tool_call_id": call_id, } - - # Check for blocking guardrails - guardrailing_result = check_blocking_guardrails(ctx, message, result) + # Check for blocking guardrails - this blocks until completion + guardrailing_result = get_guardrails_check_result( + ctx, message, action=GuardrailAction.BLOCK + ) if guardrailing_result and guardrailing_result.get("errors", []): - result["result"]["content"] = [ - { - "type": "text", - "text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n" - + json.dumps(guardrailing_result["errors"]), - } - ] + result = { + "jsonrpc": "2.0", + "id": result.get("id"), + "error": { + "code": -32600, + "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE + % guardrailing_result["errors"], + }, + } - append_and_push_trace(ctx, message) + if ctx.push_explorer: + # Run append_and_push_trace in background + run_task_in_background( + append_and_push_trace, ctx, message, guardrailing_result + ) return result - elif method == "tools/list": + elif method == MCP_LIST_TOOLS: ctx.tools = result.get("result").get("tools") return result else: return result -def forward_stdout(process, ctx, buffer_size=1): - """Read from the process stdout, parse JSON chunks, and forward to sys.stdout""" - buffer = b"" - - while True: - chunk = process.stdout.read(buffer_size) - if not chunk: - break - buffer += chunk - +def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None: + """Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout""" + for line in 
iter(mcp_process.stdout.readline, b""): try: - # Try parsing full JSON object from buffer - text = buffer.decode("utf-8") - obj = json.loads(text) + # Process complete JSON lines + line_str = line.decode(UTF_8_ENCODING).strip() + if not line_str: + continue - obj = hook_tool_result(ctx, obj) - # clear the buffer - buffer = b"" + parsed_json = json.loads(line_str) + processed_json = hook_tool_result(ctx, parsed_json) - # Forward the original JSON to stdout - json_output = json.dumps(obj).encode("utf-8") + b"\n" - sys.stdout.buffer.write(json_output) + # Write and flush immediately + sys.stdout.buffer.write(write_as_utf8_bytes(processed_json)) sys.stdout.buffer.flush() - except (json.JSONDecodeError, UnicodeDecodeError): - # Wait for more data - continue + except json.JSONDecodeError as je: + mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}") + mcp_log(f"[ERROR] Problematic line: {line[:200]}...") + + except Exception as e: + mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}") + if line: + mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...") -def forward_stderr(process, ctx, buffer_size=1): - """Read from the process stderr and write to sys.stderr""" - for line in iter(lambda: process.stderr.read(buffer_size), b""): - ctx.log_out.buffer.write(line) - ctx.log_out.buffer.flush() +def stream_and_forward_stderr( + mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1 +) -> None: + """Read from the mcp_process stderr and write to sys.stderr""" + for line in iter(lambda: mcp_process.stderr.read(read_chunk_size), b""): + MCP_LOG_FILE.buffer.write(line) + MCP_LOG_FILE.buffer.flush() -def execute(args=None): - """Main function to execute the MCP gateway.""" - if "INVARIANT_API_KEY" not in os.environ: - print("[ERROR] INVARIANT_API_KEY environment variable is not set.") - sys.exit(1) - - # Split args at the "--exec" boundary - if args and "--exec" in args: - exec_index = args.index("--exec") - pre_exec_args = 
args[:exec_index] - post_exec_args = args[exec_index + 1 :] - else: - pre_exec_args = args or [] - post_exec_args = [] - - if not post_exec_args: - print("[ERROR] No command provided after --exec.") - sys.exit(1) - - # Parse pre-exec args using argparse - parser = argparse.ArgumentParser(description="MCP Gateway") - parser.add_argument("--directory", help="Working directory") - parser.add_argument("--verbose", action="store_true", help="Enable verbose output") - config = parser.parse_args(pre_exec_args) - # Initialize the singleton context using config - ctx = McpContext() - - # Can now use post_exec_args as your cmd - cmd = post_exec_args - - process = subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - bufsize=0, # No buffering - ) - - # Start threads to forward stdout and stderr - stdout_thread = threading.Thread( - target=forward_stdout, args=(process, ctx), daemon=True - ) - stderr_thread = threading.Thread( - target=forward_stderr, args=(process, ctx), daemon=True - ) - stdout_thread.start() - stderr_thread.start() - - # Handle forwarding stdin and intercept tool calls +def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None: + """Handle standard input, intercept call and forward requests to mcp_process stdin.""" try: current_chunk = b"" while True: - data = sys.stdin.buffer.read(1) - current_chunk += data + buffer_input = sys.stdin.buffer.read(1) + current_chunk += buffer_input - if not data: + if not buffer_input: break # Try to decode and parse as JSON to check for tool calls try: - text = current_chunk.decode("utf-8") - obj = json.loads(text) - # clear the current chunk + text = current_chunk.decode(UTF_8_ENCODING) + parsed_json = json.loads(text) + # clear the current chunk after successful parse current_chunk = b"" + # Refresh guardrails + run_task_sync(ctx.load_guardrails) - if obj.get("method") is not None: - ctx.id_to_method_mapping[obj.get("id")] = obj.get("method") + if 
parsed_json.get(MCP_METHOD) is not None: + ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get( + MCP_METHOD + ) # Check if this is a tool call request - if obj.get("method") == "tools/call": - # Intercept and potentially modify the request - obj = hook_tool_call(ctx, obj) - # Convert back to bytes - data = json.dumps(obj).encode("utf-8") - - # Forward to the process - process.stdin.write(data + b"\n") - process.stdin.flush() + if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL: + # Intercept and potentially block the request + hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json) + if not is_blocked: + # If not blocked, hook_tool_call_result contains the original request. + # Forward the request to the MCP process. + # It will handle the request and return a response. + mcp_process.stdin.write( + write_as_utf8_bytes(hook_tool_call_result) + ) + mcp_process.stdin.flush() + else: + # If blocked, hook_tool_call_result contains the block message. + # Forward the block message result back to the caller. + # The original request is not passed to the MCP process. + sys.stdout.buffer.write( + write_as_utf8_bytes(hook_tool_call_result) + ) + sys.stdout.buffer.flush() continue else: - process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n") - process.stdin.flush() + mcp_process.stdin.write(write_as_utf8_bytes(parsed_json)) + mcp_process.stdin.flush() continue except Exception: # Not a complete or valid JSON, just pass through @@ -278,25 +359,69 @@ def execute(args=None): except BrokenPipeError: pass except KeyboardInterrupt: - process.terminate() + mcp_process.terminate() + + +def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: + """ + Splits CLI arguments into two parts: + 1. Arguments intended for the MCP gateway (everything before `--exec`) + 2. Arguments for the underlying MCP server (everything after `--exec`) + + Parameters: + args (list[str]): The list of CLI arguments. 
+ + Returns: + Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args) + """ + if not args: + mcp_log("[ERROR] No arguments provided.") + sys.exit(1) - # Wait for process to terminate try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() - process.wait() + exec_index = args.index("--exec") + except ValueError: + mcp_log("[ERROR] '--exec' flag not found in arguments.") + sys.exit(1) + + mcp_gateway_args = args[:exec_index] + mcp_server_command_args = args[exec_index + 1 :] + + if not mcp_server_command_args: + mcp_log("[ERROR] No arguments provided after '--exec'.") + sys.exit(1) + + return mcp_gateway_args, mcp_server_command_args -# Handle signals to ensure clean shutdown -def signal_handler(sig, frame): - """Handle signals for graceful shutdown.""" - ctx = McpContext() - custom_print(ctx, f"Received signal {sig}, shutting down...") - sys.exit(0) +async def execute(args: list[str] = None): + """Main function to execute the MCP gateway.""" + if "INVARIANT_API_KEY" not in os.environ: + mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.") + sys.exit(1) + mcp_gateway_args, mcp_server_command_args = split_args(args) + ctx = McpContext(mcp_gateway_args) -if __name__ == "__main__": - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - execute(sys.argv) + mcp_process = subprocess.Popen( + mcp_server_command_args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=0, + ) + + # Start threads to forward stdout and stderr + threading.Thread( + target=stream_and_forward_stdout, + args=(mcp_process, ctx), + daemon=True, + ).start() + threading.Thread( + target=stream_and_forward_stderr, + args=(mcp_process, ctx), + daemon=True, + ).start() + + # Handle forwarding stdin and intercept tool calls + run_stdio_input_loop(ctx, mcp_process) diff --git a/gateway/mcp/mcp_context.py b/gateway/mcp/mcp_context.py index 
43eb849..a4dced6 100644 --- a/gateway/mcp/mcp_context.py +++ b/gateway/mcp/mcp_context.py @@ -1,9 +1,13 @@ """Context manager for MCP (Model Context Protocol) gateway.""" -import atexit +import argparse import os -import sys -from invariant_sdk.client import Client +import random + +from gateway.integrations.explorer import ( + fetch_guardrails_from_explorer, +) +from gateway.common.guardrails import GuardrailRuleSet class McpContext: @@ -11,44 +15,53 @@ class McpContext: _instance = None - def __new__(cls): - """Control instance creation to ensure only one instance exists.""" + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(McpContext, cls).__new__(cls) cls._instance._initialized = False return cls._instance - def __init__(self): - """Initialize the singleton instance with default values (only once).""" - # Define _initialized attribute explicitly at the beginning to avoid warnings - # This is redundant but prevents warnings about accessing before definition + def __init__(self, cli_args: list): if not hasattr(self, "_initialized"): self._initialized = False - if self._initialized: return - def setup_logging(self): - """Set up logging to a file in the user's home directory. - - Uses proper resource management to ensure the file is closed on program exit. 
- """ - os.makedirs( - os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True - ) - log_path = os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log") - self.log_out = open(log_path, "a", buffering=1, encoding="utf-8") - atexit.register(self.log_out.close) - sys.stderr = self.log_out - - self.client = Client() - self.explorer_dataset = "mcp-capture" + config = self._parse_cli_args(cli_args) + self.explorer_dataset = config.dataset_name + self.push_explorer = config.push_explorer self.trace = [] self.tools = [] + self.guardrails = GuardrailRuleSet( + blocking_guardrails=[], logging_guardrails=[] + ) + # We send the same trace messages for guardrails analysis multiple times. + # We need to deduplicate them before sending to the explorer. + self.annotations = [] self.trace_id = None self.last_trace_length = 0 - self.guardrails = None self.id_to_method_mapping = {} - setup_logging(self) - # Mark as initialized self._initialized = True + + def _parse_cli_args(self, cli_args: list) -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="MCP Gateway") + parser.add_argument( + "--dataset-name", + help="Name of the dataset where we want to push the MCP traces", + type=str, + default=f"mcp-capture-{random.randint(1, 100)}", + ) + parser.add_argument( + "--push-explorer", + help="Enable pushing traces to Invariant Explorer", + action="store_true", + ) + + return parser.parse_args(cli_args) + + async def load_guardrails(self): + """Run async setup logic (e.g. 
fetching guardrails).""" + self.guardrails = await fetch_guardrails_from_explorer( + self.explorer_dataset, "Bearer " + os.getenv("INVARIANT_API_KEY") + ) diff --git a/gateway/mcp/task_utils.py b/gateway/mcp/task_utils.py new file mode 100644 index 0000000..b507d0c --- /dev/null +++ b/gateway/mcp/task_utils.py @@ -0,0 +1,72 @@ +"""Task utilities for running async functions""" + +import asyncio +import concurrent.futures +import threading + +from contextlib import redirect_stdout +from typing import Any + +from gateway.mcp.log import MCP_LOG_FILE, mcp_log + + +def run_task_in_background(async_func, *args, **kwargs): + """ + Runs an async function in a background thread with its own event loop. + This function does NOT block the calling thread as it immediately returns + after starting the background thread. + + Args: + async_func: The async function to run + *args: Positional arguments to pass to the async function + **kwargs: Keyword arguments to pass to the async function + """ + + def thread_target(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(async_func(*args, **kwargs)) + except Exception as e: + mcp_log( + f"[ERROR] Error in async thread while running run_task_in_background: {e}" + ) + finally: + loop.close() + + # Create and start a daemon thread + thread = threading.Thread(target=thread_target, daemon=True) + thread.start() + + +def run_task_sync(async_func, *args, **kwargs) -> Any: + """ + Runs an asynchronous function synchronously in a separate + thread with its own event loop. This function blocks the calling + thread until completion or timeout (10 seconds). 
+ + Args: + async_func: The async function to run + *args: Positional arguments to pass to the async function + **kwargs: Keyword arguments to pass to the async function + + Returns: + Any: The return value of the async function + """ + + def run_in_new_loop(): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete( + async_func( + *args, + **kwargs, + ) + ) + finally: + loop.close() + + with redirect_stdout(MCP_LOG_FILE): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_new_loop) + return future.result(timeout=10.0) diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 950f15b..e294a24 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers from gateway.common.config_manager import ( GatewayConfig, GatewayConfigManager, - GuardrailsInHeader, + extract_guardrails_from_header, ) from gateway.common.constants import ( CLIENT_TIMEOUT, @@ -34,7 +34,6 @@ from gateway.integrations.guardrails import ( InstrumentedStreamingResponse, Replacement, check_guardrails, - preload_guardrails, ) gateway = APIRouter() @@ -70,7 +69,7 @@ async def anthropic_v1_messages_gateway( request: Request, dataset_name: str = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument - header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), + header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ): """Proxy calls to the Anthropic APIs""" headers = { @@ -171,7 +170,7 @@ async def push_to_explorer( """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []), action="block" + 
guardrails_execution_result.get("errors", []) ) # Execute the logging guardrails before pushing to Explorer @@ -181,7 +180,7 @@ async def push_to_explorer( response_json=merged_response, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []), action="log" + logging_guardrails_execution_result.get("errors", []) ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 06814c7..9a836e7 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers from gateway.common.config_manager import ( GatewayConfig, GatewayConfigManager, - GuardrailsInHeader, + extract_guardrails_from_header, ) from gateway.common.constants import ( CLIENT_TIMEOUT, @@ -31,7 +31,6 @@ from gateway.integrations.guardrails import ( InstrumentedResponse, InstrumentedStreamingResponse, Replacement, - preload_guardrails, check_guardrails, ) @@ -53,7 +52,7 @@ async def gemini_generate_content_gateway( None, title="Response Format", description="Set to 'sse' for streaming" ), config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument - header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), + header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ) -> Response: """Proxy calls to the Gemini GenerateContent API""" if endpoint not in ["generateContent", "streamGenerateContent"]: @@ -413,7 +412,7 @@ async def push_to_explorer( """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} annotations = create_annotations_from_guardrails_errors( - guardrails_execution_result.get("errors", []), action="block" + guardrails_execution_result.get("errors", []) ) # Execute the logging guardrails before pushing to 
Explorer @@ -423,7 +422,7 @@ async def push_to_explorer( response_json=response_json, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []), action="log" + logging_guardrails_execution_result.get("errors", []) ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index b849dab..4d25871 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers from gateway.common.config_manager import ( GatewayConfig, GatewayConfigManager, - GuardrailsInHeader, + extract_guardrails_from_header, ) from gateway.common.constants import ( CLIENT_TIMEOUT, @@ -30,7 +30,6 @@ from gateway.integrations.guardrails import ( InstrumentedResponse, InstrumentedStreamingResponse, check_guardrails, - preload_guardrails, ) gateway = APIRouter() @@ -113,7 +112,7 @@ async def openai_chat_completions_gateway( request: Request, dataset_name: str = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument - header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader), + header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ) -> Response: """Proxy calls to the OpenAI APIs""" headers = { @@ -491,9 +490,7 @@ async def push_to_explorer( # or if the guardrails check returned errors. 
guardrails_execution_result = guardrails_execution_result or {} guardrails_errors = guardrails_execution_result.get("errors", []) - annotations = create_annotations_from_guardrails_errors( - guardrails_errors, action="block" - ) + annotations = create_annotations_from_guardrails_errors(guardrails_errors) # Execute the logging guardrails before pushing to Explorer logging_guardrails_execution_result = await get_guardrails_check_result( context, @@ -501,7 +498,7 @@ async def push_to_explorer( response_json=merged_response, ) logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []), action="log" + logging_guardrails_execution_result.get("errors", []) ) # Update the annotations with the logging guardrails annotations.extend(logging_annotations)