wip: mcp integratio

This commit is contained in:
Luca Beurer-Kellner
2025-04-07 10:38:34 +02:00
parent 6b6f33bde6
commit fd9f65aabd
4 changed files with 414 additions and 6 deletions
+6 -5
View File
@@ -82,11 +82,12 @@ class RequestContext:
# if additionally provided, extract separate API key to use with guardrailing service
guardrail_service_authorization = None
if (
guardrail_authorization
:= extract_guardrail_service_authorization_from_headers(request)
):
guardrail_service_authorization = guardrail_authorization
if request is not None:
if (
guardrail_authorization
:= extract_guardrail_service_authorization_from_headers(request)
):
guardrail_service_authorization = guardrail_authorization
return cls(
request_json=request_json,
+93
View File
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
import asyncio
import json
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
async def main():
# Configure the MCP server parameters
server_params = StdioServerParameters(
command="/Users/luca/.local/bin/uv",
args=[
"run",
"/Users/luca/Developer/invariant-gateway/gateway/routes/mcp-stdio.py",
"python",
"/Users/luca/Developer/hijack-mcp/mock-whatsapp.py",
],
env={"INVARIANT_API_KEY": "inv-..."},
)
# Connect to the MCP server
print("Connecting to WhatsApp MCP server...")
async with stdio_client(server_params) as (read_stream, write_stream):
# Create the MCP client session
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
print("Initializing session...")
await session.initialize()
# First operation: List chats with their last messages
print("\n==== LISTING CHATS ====")
list_chats_args = {"include_last_message": True, "limit": 20, "page": 0}
# Call the list_chats tool
list_chats_result = await session.call_tool("list_chats", list_chats_args)
# Process and display chat results
print("Chat listing results:")
chats = []
for content in list_chats_result.content:
if content.type == "text":
try:
# Parse the JSON chat data
chat_data = json.loads(content.text)
chats.append(chat_data)
print(
f"Chat: {chat_data.get('name', 'Unknown')} ({chat_data.get('jid', 'Unknown')})"
)
if "last_message" in chat_data:
print(
f" Last message: {chat_data['last_message'].get('text', 'No text')}"
)
print(
f" Last active: {chat_data.get('last_active', 'Unknown')}"
)
print()
except json.JSONDecodeError:
print(f"Error parsing chat data: {content.text}")
# Second operation: Send a message to a specific recipient
print("\n==== SENDING MESSAGE ====")
recipient = "+13241234123" # Using the recipient from your JSONL example
send_message_args = {
"recipient": recipient,
"message": "Hello! This is an automated message from the WhatsApp MCP client.",
}
# Call the send_message tool
send_result = await session.call_tool("send_message", send_message_args)
# Display send result
print("Message send result:")
for content in send_result.content:
if content.type == "text":
try:
result_data = json.loads(content.text)
if result_data.get("success"):
print(
f"✓ Success: {result_data.get('message', 'Message sent')}"
)
else:
print(
f"✗ Error: {result_data.get('message', 'Unknown error')}"
)
except json.JSONDecodeError:
print(f"Raw response: {content.text}")
print("\nScript execution completed")
if __name__ == "__main__":
asyncio.run(main())
+314
View File
@@ -0,0 +1,314 @@
#!/usr/bin/env python3
import asyncio
import sys
import subprocess
import json
import os
import threading
import signal
from invariant_sdk.client import Client
from contextlib import redirect_stdout
# ensure ../ is on the path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from common.request_context import RequestContext
from integrations.guardrails import check_guardrails
from common.guardrails import GuardrailRuleSet
from integrations.explorer import (
fetch_guardrails_from_explorer,
)
# requires the 'INVARIANT_API_KEY' environment variable to be set
client = Client()
# trace state (continously expanded on)
EXPLORER_DATASET = "mcp-capture"
TRACE = []
TOOLS = []
trace_id = None
last_trace_length = 0
# guardrailing state
GUARDRAILS: GuardrailRuleSet | None = None
# maps JSON RPC IDs to method names
id_to_method_mapping = {}
# set stderr to be log.txt in the ~/.invariant/mcp.log
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
LOG_OUT = open(
os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"),
"a",
buffering=1,
)
sys.stderr = LOG_OUT
def print(*args, **kwargs):
from builtins import print as builtins_print
builtins_print(*args, **kwargs, file=LOG_OUT, flush=True)
def append_and_push_trace(message):
global trace_id, TRACE, last_trace_length, tool_call_ids
try:
if trace_id is None:
TRACE.append(message)
response = client.create_request_and_push_trace(
messages=[TRACE],
dataset="mcp-capture",
metadata=[{"source": "mcp", "tools": TOOLS}],
)
trace_id = response.id[0]
last_trace_length = len(TRACE)
else:
TRACE.append(message)
client.create_request_and_append_messages(
trace_id=trace_id, messages=TRACE[last_trace_length:]
)
last_trace_length = len(TRACE)
except Exception as e:
import traceback
print(traceback.format_exc())
def check_blocking_guardrails(message, request):
try:
guardrails = fetch_guardrails(EXPLORER_DATASET)
context = RequestContext.create(
request_json=request,
dataset_name=EXPLORER_DATASET,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=guardrails,
)
with redirect_stdout(LOG_OUT):
return asyncio.run(
check_guardrails(
messages=TRACE + [message],
guardrails=guardrails.blocking_guardrails,
context=context,
)
)
except Exception as e:
import traceback
print(traceback.format_exc())
raise e
def hook_tool_call(request):
"""
Hook function to intercept tool calls.
Modify this function to change behavior for tool calls.
Returns the potentially modified request.
"""
global trace_id, TRACE
tool_call = {
"id": f"call_{request.get('id')}",
"type": "function",
"function": {
"name": request["params"]["name"],
"arguments": request["params"]["arguments"],
},
}
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
# Check for blocking guardrails
result = check_blocking_guardrails(message, request)
print("Guardrails check tool call result:", result)
append_and_push_trace(message)
return request
def hook_tool_result(result):
"""
Hook function to intercept tool results.
Modify this function to change behavior for tool results.
Returns the potentially modified result.
"""
global TOOLS
method = id_to_method_mapping.get(result.get("id"))
call_id = f"call_{result.get('id')}"
if method is None:
return result
elif method == "tools/call":
message = {
"role": "tool",
"content": json.dumps(result.get("result").get("content")),
"error": result.get("result").get("error"),
"tool_call_id": call_id,
}
# Check for blocking guardrails
guardrailing_result = check_blocking_guardrails(message, result)
print("Guardrails check tool output result:", guardrailing_result)
if len(guardrailing_result["errors"]) > 0:
result["result"]["content"] = [
{
"type": "text",
"text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n"
+ json.dumps(guardrailing_result["errors"]),
}
]
append_and_push_trace(message)
return result
elif method == "tools/list":
TOOLS = result.get("result").get("tools")
return result
else:
return result
def forward_stdout(process, buffer_size=1):
"""Read from the process stdout, parse JSON chunks, and forward to sys.stdout"""
buffer = b""
decoder = json.JSONDecoder()
while True:
chunk = process.stdout.read(buffer_size)
if not chunk:
break
buffer += chunk
try:
# Try parsing full JSON object from buffer
text = buffer.decode("utf-8")
obj = json.loads(text)
obj = hook_tool_result(obj)
# clear the buffer
buffer = b""
# Forward the original JSON to stdout
json_output = json.dumps(obj).encode("utf-8") + b"\n"
sys.stdout.buffer.write(json_output)
sys.stdout.buffer.flush()
except (json.JSONDecodeError, UnicodeDecodeError):
# Wait for more data
continue
def forward_stderr(process, buffer_size=1):
"""Read from the process stderr and write to sys.stderr"""
for line in iter(lambda: process.stderr.read(buffer_size), b""):
LOG_OUT.buffer.write(line)
LOG_OUT.buffer.flush()
def fetch_guardrails(dataset):
# Use async fetch_guardrails_from_explorer in a thread
return asyncio.run(
fetch_guardrails_from_explorer(
dataset, "Bearer " + os.getenv("INVARIANT_API_KEY")
)
)
def main():
if len(sys.argv) < 2:
print(f"Usage: {sys.argv[0]} <mcp-implementation> [args...]")
sys.exit(1)
# Start the actual MCP implementation
cmd = [sys.argv[1]] + sys.argv[2:]
process = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0, # No buffering
)
# Ensure we have an INVARIANT_API_KEY
if "INVARIANT_API_KEY" not in os.environ:
print("INVARIANT_API_KEY environment variable is not set.")
sys.exit(1)
# Start threads to forward stdout and stderr
stdout_thread = threading.Thread(
target=forward_stdout, args=(process,), daemon=True
)
stderr_thread = threading.Thread(
target=forward_stderr, args=(process,), daemon=True
)
stdout_thread.start()
stderr_thread.start()
# Handle forwarding stdin and intercept tool calls
try:
current_chunk = b""
while True:
data = sys.stdin.buffer.read(1)
current_chunk += data
if not data:
break
# Try to decode and parse as JSON to check for tool calls
try:
text = current_chunk.decode("utf-8")
obj = json.loads(text)
# clear the current chunk
current_chunk = b""
if obj.get("method") is not None:
id_to_method_mapping[obj.get("id")] = obj.get("method")
# Check if this is a tool call request
if obj.get("method") == "tools/call":
# Intercept and potentially modify the request
obj = hook_tool_call(obj)
# Convert back to bytes
data = json.dumps(obj).encode("utf-8")
# Forward to the process
process.stdin.write(data + b"\n")
process.stdin.flush()
continue
else:
process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n")
process.stdin.flush()
continue
except Exception:
# Not a complete or valid JSON, just pass through
pass
except BrokenPipeError:
pass
except KeyboardInterrupt:
# Clean termination on Ctrl+C
process.terminate()
# Wait for process to terminate
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
# Handle signals to ensure clean shutdown
def signal_handler(sig, frame):
sys.exit(0)
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
main()
+1 -1
View File
@@ -53,7 +53,7 @@ def make_cors_response(request: Request, allow_methods: str) -> Response:
headers={
"Access-Control-Allow-Origin": request.headers.get("origin", "*"),
"Access-Control-Allow-Methods": f"{allow_methods}, OPTIONS",
"Access-Control-Allow-Headers": "Authorization, Content-Type",
"Access-Control-Allow-Headers": "Authorization, Content-Type, Invariant-Authorization, Invariant-Guardrails, Invariant-Guardrails-Authorization, Origin",
"Access-Control-Max-Age": "86400",
},
)