From fd9f65aabd033dd163b5470c40c998b0be2df5f7 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Mon, 7 Apr 2025 10:38:34 +0200 Subject: [PATCH] wip: mcp integratio --- gateway/common/request_context.py | 11 +- gateway/routes/client.py | 93 +++++++++ gateway/routes/mcp-stdio.py | 314 ++++++++++++++++++++++++++++++ gateway/routes/open_ai.py | 2 +- 4 files changed, 414 insertions(+), 6 deletions(-) create mode 100644 gateway/routes/client.py create mode 100644 gateway/routes/mcp-stdio.py diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py index 68cb1d7..a51de72 100644 --- a/gateway/common/request_context.py +++ b/gateway/common/request_context.py @@ -82,11 +82,12 @@ class RequestContext: # if additionally provided, extract separate API key to use with guardrailing service guardrail_service_authorization = None - if ( - guardrail_authorization - := extract_guardrail_service_authorization_from_headers(request) - ): - guardrail_service_authorization = guardrail_authorization + if request is not None: + if ( + guardrail_authorization + := extract_guardrail_service_authorization_from_headers(request) + ): + guardrail_service_authorization = guardrail_authorization return cls( request_json=request_json, diff --git a/gateway/routes/client.py b/gateway/routes/client.py new file mode 100644 index 0000000..16ff4a1 --- /dev/null +++ b/gateway/routes/client.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +import asyncio +import json +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client + + +async def main(): + # Configure the MCP server parameters + server_params = StdioServerParameters( + command="/Users/luca/.local/bin/uv", + args=[ + "run", + "/Users/luca/Developer/invariant-gateway/gateway/routes/mcp-stdio.py", + "python", + "/Users/luca/Developer/hijack-mcp/mock-whatsapp.py", + ], + env={"INVARIANT_API_KEY": "inv-..."}, + ) + + # Connect to the MCP server + print("Connecting to WhatsApp MCP server...") + async with stdio_client(server_params) as (read_stream, write_stream): + # Create the MCP client session + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + print("Initializing session...") + await session.initialize() + + # First operation: List chats with their last messages + print("\n==== LISTING CHATS ====") + list_chats_args = {"include_last_message": True, "limit": 20, "page": 0} + + # Call the list_chats tool + list_chats_result = await session.call_tool("list_chats", list_chats_args) + + # Process and display chat results + print("Chat listing results:") + chats = [] + for content in list_chats_result.content: + if content.type == "text": + try: + # Parse the JSON chat data + chat_data = json.loads(content.text) + chats.append(chat_data) + print( + f"Chat: {chat_data.get('name', 'Unknown')} ({chat_data.get('jid', 'Unknown')})" + ) + if "last_message" in chat_data: + print( + f" Last message: {chat_data['last_message'].get('text', 'No text')}" + ) + print( + f" Last active: {chat_data.get('last_active', 'Unknown')}" + ) + print() + except json.JSONDecodeError: + print(f"Error parsing chat data: {content.text}") + + # Second operation: Send a message to a specific recipient + print("\n==== SENDING MESSAGE ====") + recipient = "+13241234123" # Using the recipient from your JSONL example + + send_message_args = { + "recipient": recipient, + "message": "Hello! This is an automated message from the WhatsApp MCP client.", + } + + # Call the send_message tool + send_result = await session.call_tool("send_message", send_message_args) + + # Display send result + print("Message send result:") + for content in send_result.content: + if content.type == "text": + try: + result_data = json.loads(content.text) + if result_data.get("success"): + print( + f"✓ Success: {result_data.get('message', 'Message sent')}" + ) + else: + print( + f"✗ Error: {result_data.get('message', 'Unknown error')}" + ) + except json.JSONDecodeError: + print(f"Raw response: {content.text}") + + print("\nScript execution completed") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/gateway/routes/mcp-stdio.py b/gateway/routes/mcp-stdio.py new file mode 100644 index 0000000..35ff607 --- /dev/null +++ b/gateway/routes/mcp-stdio.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 +import asyncio +import sys +import subprocess +import json +import os +import threading +import signal +from invariant_sdk.client import Client + +from contextlib import redirect_stdout + +# ensure ../ is on the path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from common.request_context import RequestContext +from integrations.guardrails import check_guardrails +from common.guardrails import GuardrailRuleSet +from integrations.explorer import ( + fetch_guardrails_from_explorer, +) + +# requires the 'INVARIANT_API_KEY' environment variable to be set +client = Client() + +# trace state (continously expanded on) +EXPLORER_DATASET = "mcp-capture" +TRACE = [] +TOOLS = [] +trace_id = None +last_trace_length = 0 + +# guardrailing state +GUARDRAILS: GuardrailRuleSet | None = None + +# maps JSON RPC IDs to method names +id_to_method_mapping = {} + +# set stderr to be log.txt in the ~/.invariant/mcp.log +os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True) +LOG_OUT = open( + os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"), + "a", + buffering=1, +) +sys.stderr = LOG_OUT + + +def print(*args, **kwargs): + from builtins import print as builtins_print + + builtins_print(*args, **kwargs, file=LOG_OUT, flush=True) + + +def append_and_push_trace(message): + global trace_id, TRACE, last_trace_length, tool_call_ids + + try: + if trace_id is None: + TRACE.append(message) + response = client.create_request_and_push_trace( + messages=[TRACE], + dataset="mcp-capture", + metadata=[{"source": "mcp", "tools": TOOLS}], + ) + trace_id = response.id[0] + last_trace_length = len(TRACE) + else: + TRACE.append(message) + client.create_request_and_append_messages( + trace_id=trace_id, messages=TRACE[last_trace_length:] + ) + last_trace_length = len(TRACE) + except Exception as e: + import traceback + + print(traceback.format_exc()) + + +def check_blocking_guardrails(message, request): + try: + guardrails = fetch_guardrails(EXPLORER_DATASET) + + context = RequestContext.create( + request_json=request, + dataset_name=EXPLORER_DATASET, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + guardrails=guardrails, + ) + + with redirect_stdout(LOG_OUT): + return asyncio.run( + check_guardrails( + messages=TRACE + [message], + guardrails=guardrails.blocking_guardrails, + context=context, + ) + ) + except Exception as e: + import traceback + + print(traceback.format_exc()) + raise e + + +def hook_tool_call(request): + """ + Hook function to intercept tool calls. + Modify this function to change behavior for tool calls. + Returns the potentially modified request. + """ + global trace_id, TRACE + + tool_call = { + "id": f"call_{request.get('id')}", + "type": "function", + "function": { + "name": request["params"]["name"], + "arguments": request["params"]["arguments"], + }, + } + + message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} + + # Check for blocking guardrails + result = check_blocking_guardrails(message, request) + print("Guardrails check tool call result:", result) + + append_and_push_trace(message) + + return request + + +def hook_tool_result(result): + """ + Hook function to intercept tool results. + Modify this function to change behavior for tool results. + Returns the potentially modified result. + """ + global TOOLS + + method = id_to_method_mapping.get(result.get("id")) + call_id = f"call_{result.get('id')}" + + if method is None: + return result + elif method == "tools/call": + message = { + "role": "tool", + "content": json.dumps(result.get("result").get("content")), + "error": result.get("result").get("error"), + "tool_call_id": call_id, + } + + # Check for blocking guardrails + guardrailing_result = check_blocking_guardrails(message, result) + print("Guardrails check tool output result:", guardrailing_result) + + if len(guardrailing_result["errors"]) > 0: + result["result"]["content"] = [ + { + "type": "text", + "text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n" + + json.dumps(guardrailing_result["errors"]), + } + ] + + append_and_push_trace(message) + + return result + elif method == "tools/list": + TOOLS = result.get("result").get("tools") + return result + else: + return result + + +def forward_stdout(process, buffer_size=1): + """Read from the process stdout, parse JSON chunks, and forward to sys.stdout""" + buffer = b"" + decoder = json.JSONDecoder() + + while True: + chunk = process.stdout.read(buffer_size) + if not chunk: + break + buffer += chunk + + try: + # Try parsing full JSON object from buffer + text = buffer.decode("utf-8") + obj = json.loads(text) + + obj = hook_tool_result(obj) + # clear the buffer + buffer = b"" + + # Forward the original JSON to stdout + json_output = json.dumps(obj).encode("utf-8") + b"\n" + sys.stdout.buffer.write(json_output) + sys.stdout.buffer.flush() + except (json.JSONDecodeError, UnicodeDecodeError): + # Wait for more data + continue + + +def forward_stderr(process, buffer_size=1): + """Read from the process stderr and write to sys.stderr""" + for line in iter(lambda: process.stderr.read(buffer_size), b""): + LOG_OUT.buffer.write(line) + LOG_OUT.buffer.flush() + + +def fetch_guardrails(dataset): + # Use async fetch_guardrails_from_explorer in a thread + return asyncio.run( + fetch_guardrails_from_explorer( + dataset, "Bearer " + os.getenv("INVARIANT_API_KEY") + ) + ) + + +def main(): + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [args...]") + sys.exit(1) + + # Start the actual MCP implementation + cmd = [sys.argv[1]] + sys.argv[2:] + process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=0, # No buffering + ) + + # Ensure we have an INVARIANT_API_KEY + if "INVARIANT_API_KEY" not in os.environ: + print("INVARIANT_API_KEY environment variable is not set.") + sys.exit(1) + + # Start threads to forward stdout and stderr + stdout_thread = threading.Thread( + target=forward_stdout, args=(process,), daemon=True + ) + stderr_thread = threading.Thread( + target=forward_stderr, args=(process,), daemon=True + ) + stdout_thread.start() + stderr_thread.start() + + # Handle forwarding stdin and intercept tool calls + try: + current_chunk = b"" + + while True: + data = sys.stdin.buffer.read(1) + current_chunk += data + + if not data: + break + + # Try to decode and parse as JSON to check for tool calls + try: + text = current_chunk.decode("utf-8") + obj = json.loads(text) + # clear the current chunk + current_chunk = b"" + + if obj.get("method") is not None: + id_to_method_mapping[obj.get("id")] = obj.get("method") + + # Check if this is a tool call request + if obj.get("method") == "tools/call": + # Intercept and potentially modify the request + obj = hook_tool_call(obj) + # Convert back to bytes + data = json.dumps(obj).encode("utf-8") + + # Forward to the process + process.stdin.write(data + b"\n") + process.stdin.flush() + continue + else: + process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n") + process.stdin.flush() + continue + except Exception: + # Not a complete or valid JSON, just pass through + pass + + except BrokenPipeError: + pass + except KeyboardInterrupt: + # Clean termination on Ctrl+C + process.terminate() + + # Wait for process to terminate + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + +# Handle signals to ensure clean shutdown +def signal_handler(sig, frame): + sys.exit(0) + + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + main() diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 5695c03..f03c276 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -53,7 +53,7 @@ def make_cors_response(request: Request, allow_methods: str) -> Response: headers={ "Access-Control-Allow-Origin": request.headers.get("origin", "*"), "Access-Control-Allow-Methods": f"{allow_methods}, OPTIONS", - "Access-Control-Allow-Headers": "Authorization, Content-Type", + "Access-Control-Allow-Headers": "Authorization, Content-Type, Invariant-Authorization, Invariant-Guardrails, Invariant-Guardrails-Authorization, Origin", "Access-Control-Max-Age": "86400", }, )