Fix asyncio and threading. Dedupe annotations before pushing. Add README.

This commit is contained in:
Hemang
2025-04-15 17:25:35 +02:00
committed by Hemang Sarkar
parent f871e24473
commit 2c34205c4c
11 changed files with 593 additions and 252 deletions
+30 -12
View File
@@ -1,34 +1,52 @@
"""Script is used to run actions using the Invariant Gateway."""
import asyncio
import signal
import sys
from gateway.mcp import mcp
def main():
"""Entry point for the Invariant Gateway."""
# Handle signals to ensure clean shutdown
def signal_handler(sig, frame):
"""Handle signals for graceful shutdown."""
sys.exit(0)
def print_help():
"""Prints the help message."""
actions = {
"mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features",
"llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features",
"help": "Shows this help message",
"mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features.",
"llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features. Not implemented yet.",
"help": "Shows this help message.",
}
def _help():
"""_prints the help message."""
for verb, description in actions.items():
print(f"{verb}: {description}")
for verb, description in actions.items():
print(f"{verb}: {description}")
def main():
"""Entry point for the Invariant Gateway."""
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if len(sys.argv) < 2:
_help()
print_help()
sys.exit(1)
verb = sys.argv[1]
if verb == "mcp":
return mcp.execute(sys.argv[2:])
print("[MCP] Running Invariant Gateway against MCP servers...")
return asyncio.run(mcp.execute(sys.argv[2:]))
if verb == "llm":
print("[gateway/__main__.py] 'llm' action is not implemented yet.")
return 1
if verb == "help":
_help()
print_help()
return 0
print(f"[gateway/__main__.py] Unknown action: {verb}")
return 1
if __name__ == "__main__":
sys.exit(main())
+3 -1
View File
@@ -93,7 +93,9 @@ class GatewayConfigManager:
return local_config
async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]:
async def extract_guardrails_from_header(
request: fastapi.Request,
) -> Optional[GuardrailRuleSet]:
"""
Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding
GuardrailRuleSet. If no guardrails are provided, returns None.
+40 -6
View File
@@ -1,8 +1,9 @@
"""Utility functions for the Invariant explorer."""
import os
from typing import Any, Dict, List
import json
from typing import Any, Dict, List
from fastapi import HTTPException
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
@@ -16,17 +17,20 @@ DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def create_annotations_from_guardrails_errors(
guardrails_errors: List[dict], action: str = "block"
guardrails_errors: List[dict],
) -> List[AnnotationCreate]:
"""Create Explorer annotations from the guardrails errors."""
annotations = []
def _remove_prefixes(ranges: list[str]) -> list[str]:
def _pick_most_specific_ranges(ranges: list[str]) -> list[str]:
"""
Remove prefixes from the list of ranges.
Remove redundant prefixes from the list of ranges.
If the ranges are ['messages.2', 'messages.2.content:25-30', 'messages.2.content']
then this returns ['messages.2.content:25-30'].
This picks the most specific subset of the ranges and removes the rest. If some
range is a proper prefix of another range, it is removed.
"""
ranges = sorted(ranges, key=len)
result = []
@@ -44,7 +48,7 @@ def create_annotations_from_guardrails_errors(
for error in guardrails_errors:
content = error.get("args")[0]
filtered_ranges = _remove_prefixes(list(error.get("ranges", [])))
filtered_ranges = _pick_most_specific_ranges(list(error.get("ranges", [])))
for r in filtered_ranges:
annotations.append(
AnnotationCreate(
@@ -61,10 +65,40 @@ def create_annotations_from_guardrails_errors(
},
)
)
return annotations
# Remove duplicates
# TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class
# to remove duplicates instead of using a custom function.
# This is a temporary solution until the Invariant SDK is updated.
return remove_duplicates(annotations)
def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]:
    """
    Drop repeated annotations, keeping the first occurrence of each in input order.

    Two annotations are treated as the same when their content, address and
    extra_metadata all match.
    """
    first_by_key = {}
    for candidate in annotations:
        # extra_metadata is a dict and therefore unhashable; a key-sorted JSON
        # dump gives a stable, hashable stand-in for it.
        key = (
            candidate.content,
            candidate.address,
            json.dumps(candidate.extra_metadata, sort_keys=True),
        )
        # setdefault only stores the first annotation seen for a given key.
        first_by_key.setdefault(key, candidate)
    # Dicts preserve insertion order, so this is first-occurrence order.
    return list(first_by_key.values())
def get_explorer_api_url() -> str:
    """Return INVARIANT_API_URL from the environment, or DEFAULT_API_URL when it is unset."""
    return os.environ.get("INVARIANT_API_URL", DEFAULT_API_URL)
+63
View File
@@ -0,0 +1,63 @@
This is a work-in-progress implementation of MCP (Model Context Protocol) support in the Gateway.
This package will be published to PyPI, after which using it from the MCP config file will be simpler.
For now, if the original MCP config file looks like this:
```
{
"mcpServers": {
"weather": {
"command": "uv",
"args": [
"--directory",
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
"run",
"weather.py"
]
}
}
}
```
You need to:
1. Check out the invariant-gateway repo.
2. Run `python -m build`. This will generate a .whl file in dist.
3. Modify the MCP config like this:
```
{
"mcpServers": {
"weather": {
"command": "uvx",
"args": [
"--refresh",
"--from",
"/ABSOLUTE/PATH/TO/INVARIANT_GATEWAY_REPO/dist/invariant_gateway-0.0.1-py3-none-any.whl",
"invariant-gateway",
"mcp",
"--dataset-name",
"weather-testing",
"--push-explorer",
"--exec",
"uv",
"--directory",
"/Users/hemang/Sdk/mcp/weather",
"run",
"weather.py"
],
"env": {
"INVARIANT_API_KEY": "<Add Invariant API key here>"
}
}
}
}
```
This moves the original `command` and `args` to the `args` list after the `--exec` flag.
All args before the `--exec` flag are relevant to the Invariant MCP gateway. These include:
- `--dataset-name`: With this you can specify the name of the dataset. The guardrails are pulled from this dataset.
- `--push-explorer`: With this you can specify if you want to push the annotated traces to the Invariant Explorer.
+19
View File
@@ -0,0 +1,19 @@
"""Cusym log configuration."""
import os
import sys
from builtins import print as builtins_print

# Import-time side effects: make sure ~/.invariant exists, open the shared log
# file in append mode (line-buffered), and point sys.stderr at it.
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
MCP_LOG_FILE = open(
    os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"),
    "a",
    buffering=1,
    # Explicit encoding so log output does not depend on the platform's
    # default locale encoding.
    encoding="utf-8",
)
sys.stderr = MCP_LOG_FILE


def mcp_log(*args, **kwargs) -> None:
    """
    Print to the MCP log file instead of stdout.

    Accepts the same arguments as the built-in print. The output always goes
    to MCP_LOG_FILE and is always flushed; any `file` or `flush` passed by the
    caller is overridden rather than raising a duplicate-keyword TypeError.
    """
    kwargs["file"] = MCP_LOG_FILE
    kwargs["flush"] = True
    builtins_print(*args, **kwargs)
+313 -188
View File
@@ -1,99 +1,181 @@
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
import argparse
import asyncio
import sys
import subprocess
import json
import os
import threading
import signal
from builtins import print as builtins_print
from contextlib import redirect_stdout
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
from invariant_sdk.types.push_traces import PushTracesRequest
from gateway.common.guardrails import GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
from gateway.integrations.guardrails import check_guardrails
from gateway.integrations.explorer import (
fetch_guardrails_from_explorer,
)
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
from gateway.mcp.mcp_context import McpContext
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
MCP_METHOD = "method"
UTF_8_ENCODING = "utf-8"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """
[Invariant Guardrails] The MCP tool call was blocked for security reasons.
Do not attempt to circumvent this block, rather explain to the user based
on the following output what went wrong: %s
"""
def custom_print(ctx, *args, **kwargs):
"""Custom print function to redirect output to log_out."""
builtins_print(*args, **kwargs, file=ctx.log_out, flush=True)
def write_as_utf8_bytes(data: dict) -> bytes:
    """Encode data as one newline-terminated JSON line in UTF-8 bytes."""
    json_line = json.dumps(data) + "\n"
    return json_line.encode(UTF_8_ENCODING)
def append_and_push_trace(ctx, message):
def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list:
    """Return the entries of new_annotations that are not already in ctx.annotations."""

    # TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
    # to remove duplicates instead of using a custom logic.
    # This is a temporary solution until the Invariant SDK is updated.
    def _already_recorded(candidate) -> bool:
        # A match requires content, address and extra_metadata to all be equal.
        return any(
            candidate.content == known.content
            and candidate.address == known.address
            and candidate.extra_metadata == known.extra_metadata
            for known in ctx.annotations
        )

    return [
        candidate
        for candidate in new_annotations
        if not _already_recorded(candidate)
    ]
def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
    """
    Check whether the guardrails result contains errors not yet recorded in the context.

    The errors are converted to annotations and compared against
    ctx.annotations. Returns True if at least one annotation is not already
    present, False otherwise (including when there are no errors).
    """
    annotations = create_annotations_from_guardrails_errors(
        guardrails_result.get("errors", [])
    )
    # Use the same field-wise comparison as deduplicate_annotations rather than
    # `annotation not in ctx.annotations`: the TODO there notes that the
    # AnnotationCreate __eq__ cannot be relied on yet, and the two checks must
    # agree on what counts as "already seen".
    return bool(deduplicate_annotations(ctx, annotations))
async def append_and_push_trace(
ctx: McpContext, message: dict, guardrails_result: dict
) -> None:
"""
Append a message to the trace if it exists or create a new one
and push it to the Invariant Explorer.
This function runs asynchronously in the background.
"""
annotations = []
if guardrails_result and guardrails_result.get("errors", []):
annotations = create_annotations_from_guardrails_errors(
guardrails_result["errors"]
)
if ctx.guardrails.logging_guardrails:
logging_guardrails_check_result = get_guardrails_check_result(
ctx, message, action=GuardrailAction.LOG
)
if logging_guardrails_check_result and logging_guardrails_check_result.get(
"errors", []
):
annotations.extend(
create_annotations_from_guardrails_errors(
logging_guardrails_check_result["errors"]
)
)
deduplicated_annotations = deduplicate_annotations(ctx, annotations)
try:
# If the trace_id is None, create a new trace with the messages.
# Otherwise, append the message to the existing trace.
client = AsyncClient()
if ctx.trace_id is None:
ctx.trace.append(message)
response = ctx.client.create_request_and_push_trace(
messages=[ctx.trace],
dataset=ctx.explorer_dataset,
metadata=[{"source": "mcp", "tools": ctx.tools}],
response = await client.push_trace(
PushTracesRequest(
messages=[ctx.trace],
dataset=ctx.explorer_dataset,
metadata=[{"source": "mcp", "tools": ctx.tools}],
annotations=[deduplicated_annotations],
)
)
ctx.trace_id = response.id[0]
ctx.last_trace_length = len(ctx.trace)
ctx.annotations.extend(deduplicated_annotations)
else:
ctx.trace.append(message)
ctx.client.create_request_and_append_messages(
trace_id=ctx.trace_id, messages=ctx.trace[ctx.last_trace_length :]
response = await client.append_messages(
AppendMessagesRequest(
trace_id=ctx.trace_id,
messages=ctx.trace[ctx.last_trace_length :],
annotations=deduplicated_annotations,
)
)
ctx.last_trace_length = len(ctx.trace)
ctx.annotations.extend(deduplicated_annotations)
except Exception as e:
custom_print(ctx, "Error pushing trace:", e)
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
def fetch_guardrails(ctx, dataset):
"""Fetch guardrails from the Invariant Explorer."""
# Use async fetch_guardrails_from_explorer in a thread
return asyncio.run(
fetch_guardrails_from_explorer(
dataset, "Bearer " + os.getenv("INVARIANT_API_KEY")
)
def get_guardrails_check_result(
    ctx: McpContext,
    message: dict,
    action: GuardrailAction = GuardrailAction.BLOCK,
) -> dict:
    """
    Run the guardrails of the given action type against the trace plus message.

    Returns an empty dict when no guardrails are configured for the action.
    Otherwise the check is executed through run_task_sync, which blocks the
    caller until the result is available.
    """
    blocking_rules = ctx.guardrails.blocking_guardrails
    logging_rules = ctx.guardrails.logging_guardrails

    # Pick the rule list that matches the requested action.
    if action == GuardrailAction.BLOCK:
        guardrails_to_check = blocking_rules
    elif action == GuardrailAction.LOG:
        guardrails_to_check = logging_rules
    else:
        guardrails_to_check = None

    # Nothing configured for this action: skip the check entirely.
    if not guardrails_to_check:
        return {}

    request_context = RequestContext.create(
        request_json={},
        dataset_name=ctx.explorer_dataset,
        invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
        guardrails=ctx.guardrails,
    )
    # The candidate message is evaluated together with the trace so far,
    # without being added to ctx.trace here.
    return run_task_sync(
        check_guardrails,
        messages=ctx.trace + [message],
        guardrails=guardrails_to_check,
        context=request_context,
    )
def check_blocking_guardrails(ctx, message, request):
"""Check against blocking guardrails."""
try:
guardrails = fetch_guardrails(ctx, ctx.explorer_dataset)
custom_print(ctx, "Here are the guardrails: ", guardrails)
context = RequestContext.create(
request_json=request,
dataset_name=ctx.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=guardrails,
)
if guardrails.blocking_guardrails:
with redirect_stdout(ctx.log_out):
return asyncio.run(
check_guardrails(
messages=ctx.trace + [message],
guardrails=guardrails.blocking_guardrails,
context=context,
)
)
else:
return {}
except Exception as e:
custom_print(ctx, "Error checking blocking guardrails:", e)
def hook_tool_call(ctx, request):
def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
"""
Hook function to intercept tool calls.
Modify this function to change behavior for tool calls.
Returns the potentially modified request.
If the request is blocked, it returns a tuple with a message explaining the block
and a flag indicating the request was blocked.
Otherwise it returns the original request and a flag indicating it was not blocked.
"""
tool_call = {
"id": f"call_{request.get('id')}",
@@ -106,13 +188,39 @@ def hook_tool_call(ctx, request):
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
# Check for blocking guardrails
result = check_blocking_guardrails(ctx, message, request)
append_and_push_trace(ctx, message)
return request
# Check for blocking guardrails - this blocks until completion
guardrailing_result = get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
# If the request is blocked, return a message indicating the block reason.
# If there are new errors, run append_and_push_trace in background.
# If there are no new errors, just return the original request.
if (
guardrailing_result
and guardrailing_result.get("errors", [])
and check_if_new_errors(ctx, guardrailing_result)
):
if ctx.push_explorer:
run_task_in_background(
append_and_push_trace, ctx, message, guardrailing_result
)
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrailing_result["errors"],
},
}, True
# Add the message to the trace
ctx.trace.append(message)
return request, False
def hook_tool_result(ctx, result):
def hook_tool_result(ctx: McpContext, result: dict) -> dict:
"""
Hook function to intercept tool results.
Modify this function to change behavior for tool results.
@@ -123,153 +231,126 @@ def hook_tool_result(ctx, result):
if method is None:
return result
elif method == "tools/call":
elif method == MCP_TOOL_CALL:
message = {
"role": "tool",
"content": {"result": result.get("result").get("content")},
"content": result.get("result").get("content"),
"error": result.get("result").get("error"),
"tool_call_id": call_id,
}
# Check for blocking guardrails
guardrailing_result = check_blocking_guardrails(ctx, message, result)
# Check for blocking guardrails - this blocks until completion
guardrailing_result = get_guardrails_check_result(
ctx, message, action=GuardrailAction.BLOCK
)
if guardrailing_result and guardrailing_result.get("errors", []):
result["result"]["content"] = [
{
"type": "text",
"text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n"
+ json.dumps(guardrailing_result["errors"]),
}
]
result = {
"jsonrpc": "2.0",
"id": result.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrailing_result["errors"],
},
}
append_and_push_trace(ctx, message)
if ctx.push_explorer:
# Run append_and_push_trace in background
run_task_in_background(
append_and_push_trace, ctx, message, guardrailing_result
)
return result
elif method == "tools/list":
elif method == MCP_LIST_TOOLS:
ctx.tools = result.get("result").get("tools")
return result
else:
return result
def forward_stdout(process, ctx, buffer_size=1):
"""Read from the process stdout, parse JSON chunks, and forward to sys.stdout"""
buffer = b""
while True:
chunk = process.stdout.read(buffer_size)
if not chunk:
break
buffer += chunk
def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None:
"""Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout"""
for line in iter(mcp_process.stdout.readline, b""):
try:
# Try parsing full JSON object from buffer
text = buffer.decode("utf-8")
obj = json.loads(text)
# Process complete JSON lines
line_str = line.decode(UTF_8_ENCODING).strip()
if not line_str:
continue
obj = hook_tool_result(ctx, obj)
# clear the buffer
buffer = b""
parsed_json = json.loads(line_str)
processed_json = hook_tool_result(ctx, parsed_json)
# Forward the original JSON to stdout
json_output = json.dumps(obj).encode("utf-8") + b"\n"
sys.stdout.buffer.write(json_output)
# Write and flush immediately
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
sys.stdout.buffer.flush()
except (json.JSONDecodeError, UnicodeDecodeError):
# Wait for more data
continue
except json.JSONDecodeError as je:
mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}")
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
except Exception as e:
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
if line:
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
def forward_stderr(process, ctx, buffer_size=1):
"""Read from the process stderr and write to sys.stderr"""
for line in iter(lambda: process.stderr.read(buffer_size), b""):
ctx.log_out.buffer.write(line)
ctx.log_out.buffer.flush()
def stream_and_forward_stderr(
    mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1
) -> None:
    """
    Copy the mcp_process stderr stream into the MCP log file.

    Reads read_chunk_size bytes at a time and stops when a read returns
    no data. Each chunk is flushed as soon as it is written.
    """
    while True:
        chunk = mcp_process.stderr.read(read_chunk_size)
        if chunk == b"":
            # An empty read ends the loop.
            break
        MCP_LOG_FILE.buffer.write(chunk)
        MCP_LOG_FILE.buffer.flush()
def execute(args=None):
"""Main function to execute the MCP gateway."""
if "INVARIANT_API_KEY" not in os.environ:
print("[ERROR] INVARIANT_API_KEY environment variable is not set.")
sys.exit(1)
# Split args at the "--exec" boundary
if args and "--exec" in args:
exec_index = args.index("--exec")
pre_exec_args = args[:exec_index]
post_exec_args = args[exec_index + 1 :]
else:
pre_exec_args = args or []
post_exec_args = []
if not post_exec_args:
print("[ERROR] No command provided after --exec.")
sys.exit(1)
# Parse pre-exec args using argparse
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument("--directory", help="Working directory")
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
config = parser.parse_args(pre_exec_args)
# Initialize the singleton context using config
ctx = McpContext()
# Can now use post_exec_args as your cmd
cmd = post_exec_args
process = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0, # No buffering
)
# Start threads to forward stdout and stderr
stdout_thread = threading.Thread(
target=forward_stdout, args=(process, ctx), daemon=True
)
stderr_thread = threading.Thread(
target=forward_stderr, args=(process, ctx), daemon=True
)
stdout_thread.start()
stderr_thread.start()
# Handle forwarding stdin and intercept tool calls
def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None:
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
try:
current_chunk = b""
while True:
data = sys.stdin.buffer.read(1)
current_chunk += data
buffer_input = sys.stdin.buffer.read(1)
current_chunk += buffer_input
if not data:
if not buffer_input:
break
# Try to decode and parse as JSON to check for tool calls
try:
text = current_chunk.decode("utf-8")
obj = json.loads(text)
# clear the current chunk
text = current_chunk.decode(UTF_8_ENCODING)
parsed_json = json.loads(text)
# clear the current chunk after successful parse
current_chunk = b""
# Refresh guardrails
run_task_sync(ctx.load_guardrails)
if obj.get("method") is not None:
ctx.id_to_method_mapping[obj.get("id")] = obj.get("method")
if parsed_json.get(MCP_METHOD) is not None:
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
MCP_METHOD
)
# Check if this is a tool call request
if obj.get("method") == "tools/call":
# Intercept and potentially modify the request
obj = hook_tool_call(ctx, obj)
# Convert back to bytes
data = json.dumps(obj).encode("utf-8")
# Forward to the process
process.stdin.write(data + b"\n")
process.stdin.flush()
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
# Intercept and potentially block or modify the request
hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json)
if not is_blocked:
# If not blocked, hook_tool_call_result contains the original request.
# Forward the request to the MCP process.
# It will handle the request and return a response.
mcp_process.stdin.write(
write_as_utf8_bytes(hook_tool_call_result)
)
mcp_process.stdin.flush()
else:
# If blocked, hook_tool_call_result contains the block message.
# Forward the block message result back to the caller.
# The original request is not passed to the MCP process.
sys.stdout.buffer.write(
write_as_utf8_bytes(hook_tool_call_result)
)
sys.stdout.buffer.flush()
continue
else:
process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n")
process.stdin.flush()
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
mcp_process.stdin.flush()
continue
except Exception:
# Not a complete or valid JSON, just pass through
@@ -278,25 +359,69 @@ def execute(args=None):
except BrokenPipeError:
pass
except KeyboardInterrupt:
process.terminate()
mcp_process.terminate()
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
"""
Splits CLI arguments into two parts:
1. Arguments intended for the MCP gateway (everything before `--exec`)
2. Arguments for the underlying MCP server (everything after `--exec`)
Parameters:
args (list[str]): The list of CLI arguments.
Returns:
Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args)
"""
if not args:
mcp_log("[ERROR] No arguments provided.")
sys.exit(1)
# Wait for process to terminate
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
exec_index = args.index("--exec")
except ValueError:
mcp_log("[ERROR] '--exec' flag not found in arguments.")
sys.exit(1)
mcp_gateway_args = args[:exec_index]
mcp_server_command_args = args[exec_index + 1 :]
if not mcp_server_command_args:
mcp_log("[ERROR] No arguments provided after '--exec'.")
sys.exit(1)
return mcp_gateway_args, mcp_server_command_args
# Handle signals to ensure clean shutdown
def signal_handler(sig, frame):
"""Handle signals for graceful shutdown."""
ctx = McpContext()
custom_print(ctx, f"Received signal {sig}, shutting down...")
sys.exit(0)
async def execute(args: list[str] = None):
"""Main function to execute the MCP gateway."""
if "INVARIANT_API_KEY" not in os.environ:
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
sys.exit(1)
mcp_gateway_args, mcp_server_command_args = split_args(args)
ctx = McpContext(mcp_gateway_args)
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
execute(sys.argv)
mcp_process = subprocess.Popen(
mcp_server_command_args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0,
)
# Start threads to forward stdout and stderr
threading.Thread(
target=stream_and_forward_stdout,
args=(mcp_process, ctx),
daemon=True,
).start()
threading.Thread(
target=stream_and_forward_stderr,
args=(mcp_process, ctx),
daemon=True,
).start()
# Handle forwarding stdin and intercept tool calls
run_stdio_input_loop(ctx, mcp_process)
+41 -28
View File
@@ -1,9 +1,13 @@
"""Context manager for MCP (Model Context Protocol) gateway."""
import atexit
import argparse
import os
import sys
from invariant_sdk.client import Client
import random
from gateway.integrations.explorer import (
fetch_guardrails_from_explorer,
)
from gateway.common.guardrails import GuardrailRuleSet
class McpContext:
@@ -11,44 +15,53 @@ class McpContext:
_instance = None
def __new__(cls):
"""Control instance creation to ensure only one instance exists."""
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(McpContext, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Initialize the singleton instance with default values (only once)."""
# Define _initialized attribute explicitly at the beginning to avoid warnings
# This is redundant but prevents warnings about accessing before definition
def __init__(self, cli_args: list):
if not hasattr(self, "_initialized"):
self._initialized = False
if self._initialized:
return
def setup_logging(self):
"""Set up logging to a file in the user's home directory.
Uses proper resource management to ensure the file is closed on program exit.
"""
os.makedirs(
os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True
)
log_path = os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log")
self.log_out = open(log_path, "a", buffering=1, encoding="utf-8")
atexit.register(self.log_out.close)
sys.stderr = self.log_out
self.client = Client()
self.explorer_dataset = "mcp-capture"
config = self._parse_cli_args(cli_args)
self.explorer_dataset = config.dataset_name
self.push_explorer = config.push_explorer
self.trace = []
self.tools = []
self.guardrails = GuardrailRuleSet(
blocking_guardrails=[], logging_guardrails=[]
)
# We send the same trace messages for guardrails analysis multiple times.
# We need to deduplicate them before sending to the explorer.
self.annotations = []
self.trace_id = None
self.last_trace_length = 0
self.guardrails = None
self.id_to_method_mapping = {}
setup_logging(self)
# Mark as initialized
self._initialized = True
def _parse_cli_args(self, cli_args: list) -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument(
"--dataset-name",
help="Name of the dataset where we want to push the MCP traces",
type=str,
default=f"mcp-capture-{random.randint(1, 100)}",
)
parser.add_argument(
"--push-explorer",
help="Enable pushing traces to Invariant Explorer",
action="store_true",
)
return parser.parse_args(cli_args)
async def load_guardrails(self):
    """Fetch the guardrails for self.explorer_dataset and store them on self.guardrails."""
    dataset = self.explorer_dataset
    # The API key from the environment is sent as a Bearer token.
    authorization = "Bearer " + os.getenv("INVARIANT_API_KEY")
    self.guardrails = await fetch_guardrails_from_explorer(dataset, authorization)
+72
View File
@@ -0,0 +1,72 @@
"""Task utilities for running async functions"""
import asyncio
import concurrent.futures
import threading
from contextlib import redirect_stdout
from typing import Any
from gateway.mcp.log import MCP_LOG_FILE, mcp_log
def run_task_in_background(async_func, *args, **kwargs):
    """
    Fire-and-forget: run an async function on a fresh event loop in a daemon thread.

    The call returns as soon as the thread has been started. The coroutine's
    result is discarded, and an exception raised by it is written to the MCP
    log rather than propagated.

    Args:
        async_func: The async function to run
        *args: Positional arguments to pass to the async function
        **kwargs: Keyword arguments to pass to the async function
    """

    def _drive_coroutine():
        # Each background task gets its own loop, installed as the current
        # loop of the worker thread for the duration of the task.
        private_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(private_loop)
        try:
            private_loop.run_until_complete(async_func(*args, **kwargs))
        except Exception as exc:
            mcp_log(
                f"[ERROR] Error in async thread while running run_task_in_background: {exc}"
            )
        finally:
            private_loop.close()

    threading.Thread(target=_drive_coroutine, daemon=True).start()
def run_task_sync(async_func, *args, **kwargs) -> Any:
    """
    Run an async function synchronously in a separate thread with its own event loop.

    Blocks the calling thread until the coroutine completes or 10 seconds
    have elapsed. While waiting, sys.stdout is redirected to the MCP log file.

    Args:
        async_func: The async function to run
        *args: Positional arguments to pass to the async function
        **kwargs: Keyword arguments to pass to the async function

    Returns:
        Any: The return value of the async function

    Raises:
        concurrent.futures.TimeoutError: If the coroutine has not finished
            within 10 seconds. The worker thread is not interrupted and keeps
            running in the background.
        Exception: Whatever the coroutine itself raises.
    """

    def run_in_new_loop():
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(async_func(*args, **kwargs))
        finally:
            loop.close()

    # The executor is deliberately NOT used as a context manager: leaving a
    # `with ThreadPoolExecutor()` block calls shutdown(wait=True), which would
    # wait for the task to finish and so defeat the timeout below.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        # NOTE(review): redirect_stdout swaps sys.stdout for the whole process,
        # not just this thread, and is undone as soon as this block exits, so
        # output printed by a task that outlives the timeout is no longer
        # redirected - confirm this is acceptable for the stdio transport.
        with redirect_stdout(MCP_LOG_FILE):
            future = executor.submit(run_in_new_loop)
            return future.result(timeout=10.0)
    finally:
        # Release the executor without waiting for a still-running task.
        executor.shutdown(wait=False)
+4 -5
View File
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
from gateway.common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
extract_guardrails_from_header,
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
@@ -34,7 +34,6 @@ from gateway.integrations.guardrails import (
InstrumentedStreamingResponse,
Replacement,
check_guardrails,
preload_guardrails,
)
gateway = APIRouter()
@@ -70,7 +69,7 @@ async def anthropic_v1_messages_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
):
"""Proxy calls to the Anthropic APIs"""
headers = {
@@ -171,7 +170,7 @@ async def push_to_explorer(
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
annotations = create_annotations_from_guardrails_errors(
guardrails_execution_result.get("errors", []), action="block"
guardrails_execution_result.get("errors", [])
)
# Execute the logging guardrails before pushing to Explorer
@@ -181,7 +180,7 @@ async def push_to_explorer(
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
+4 -5
View File
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
from gateway.common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
extract_guardrails_from_header,
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
@@ -31,7 +31,6 @@ from gateway.integrations.guardrails import (
InstrumentedResponse,
InstrumentedStreamingResponse,
Replacement,
preload_guardrails,
check_guardrails,
)
@@ -53,7 +52,7 @@ async def gemini_generate_content_gateway(
None, title="Response Format", description="Set to 'sse' for streaming"
),
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
"""Proxy calls to the Gemini GenerateContent API"""
if endpoint not in ["generateContent", "streamGenerateContent"]:
@@ -413,7 +412,7 @@ async def push_to_explorer(
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
annotations = create_annotations_from_guardrails_errors(
guardrails_execution_result.get("errors", []), action="block"
guardrails_execution_result.get("errors", [])
)
# Execute the logging guardrails before pushing to Explorer
@@ -423,7 +422,7 @@ async def push_to_explorer(
response_json=response_json,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
+4 -7
View File
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
from gateway.common.config_manager import (
GatewayConfig,
GatewayConfigManager,
GuardrailsInHeader,
extract_guardrails_from_header,
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
@@ -30,7 +30,6 @@ from gateway.integrations.guardrails import (
InstrumentedResponse,
InstrumentedStreamingResponse,
check_guardrails,
preload_guardrails,
)
gateway = APIRouter()
@@ -113,7 +112,7 @@ async def openai_chat_completions_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
"""Proxy calls to the OpenAI APIs"""
headers = {
@@ -491,9 +490,7 @@ async def push_to_explorer(
# or if the guardrails check returned errors.
guardrails_execution_result = guardrails_execution_result or {}
guardrails_errors = guardrails_execution_result.get("errors", [])
annotations = create_annotations_from_guardrails_errors(
guardrails_errors, action="block"
)
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
@@ -501,7 +498,7 @@ async def push_to_explorer(
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
logging_guardrails_execution_result.get("errors", [])
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)