mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-25 00:04:01 +02:00
wip: mcp integratio
This commit is contained in:
@@ -82,11 +82,12 @@ class RequestContext:
|
||||
|
||||
# if additionally provided, extract separate API key to use with guardrailing service
|
||||
guardrail_service_authorization = None
|
||||
if (
|
||||
guardrail_authorization
|
||||
:= extract_guardrail_service_authorization_from_headers(request)
|
||||
):
|
||||
guardrail_service_authorization = guardrail_authorization
|
||||
if request is not None:
|
||||
if (
|
||||
guardrail_authorization
|
||||
:= extract_guardrail_service_authorization_from_headers(request)
|
||||
):
|
||||
guardrail_service_authorization = guardrail_authorization
|
||||
|
||||
return cls(
|
||||
request_json=request_json,
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import json
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
async def main():
|
||||
# Configure the MCP server parameters
|
||||
server_params = StdioServerParameters(
|
||||
command="/Users/luca/.local/bin/uv",
|
||||
args=[
|
||||
"run",
|
||||
"/Users/luca/Developer/invariant-gateway/gateway/routes/mcp-stdio.py",
|
||||
"python",
|
||||
"/Users/luca/Developer/hijack-mcp/mock-whatsapp.py",
|
||||
],
|
||||
env={"INVARIANT_API_KEY": "inv-..."},
|
||||
)
|
||||
|
||||
# Connect to the MCP server
|
||||
print("Connecting to WhatsApp MCP server...")
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
# Create the MCP client session
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
# Initialize the session
|
||||
print("Initializing session...")
|
||||
await session.initialize()
|
||||
|
||||
# First operation: List chats with their last messages
|
||||
print("\n==== LISTING CHATS ====")
|
||||
list_chats_args = {"include_last_message": True, "limit": 20, "page": 0}
|
||||
|
||||
# Call the list_chats tool
|
||||
list_chats_result = await session.call_tool("list_chats", list_chats_args)
|
||||
|
||||
# Process and display chat results
|
||||
print("Chat listing results:")
|
||||
chats = []
|
||||
for content in list_chats_result.content:
|
||||
if content.type == "text":
|
||||
try:
|
||||
# Parse the JSON chat data
|
||||
chat_data = json.loads(content.text)
|
||||
chats.append(chat_data)
|
||||
print(
|
||||
f"Chat: {chat_data.get('name', 'Unknown')} ({chat_data.get('jid', 'Unknown')})"
|
||||
)
|
||||
if "last_message" in chat_data:
|
||||
print(
|
||||
f" Last message: {chat_data['last_message'].get('text', 'No text')}"
|
||||
)
|
||||
print(
|
||||
f" Last active: {chat_data.get('last_active', 'Unknown')}"
|
||||
)
|
||||
print()
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error parsing chat data: {content.text}")
|
||||
|
||||
# Second operation: Send a message to a specific recipient
|
||||
print("\n==== SENDING MESSAGE ====")
|
||||
recipient = "+13241234123" # Using the recipient from your JSONL example
|
||||
|
||||
send_message_args = {
|
||||
"recipient": recipient,
|
||||
"message": "Hello! This is an automated message from the WhatsApp MCP client.",
|
||||
}
|
||||
|
||||
# Call the send_message tool
|
||||
send_result = await session.call_tool("send_message", send_message_args)
|
||||
|
||||
# Display send result
|
||||
print("Message send result:")
|
||||
for content in send_result.content:
|
||||
if content.type == "text":
|
||||
try:
|
||||
result_data = json.loads(content.text)
|
||||
if result_data.get("success"):
|
||||
print(
|
||||
f"✓ Success: {result_data.get('message', 'Message sent')}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"✗ Error: {result_data.get('message', 'Unknown error')}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Raw response: {content.text}")
|
||||
|
||||
print("\nScript execution completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import sys
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import signal
|
||||
from invariant_sdk.client import Client
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
# ensure ../ is on the path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from common.request_context import RequestContext
|
||||
from integrations.guardrails import check_guardrails
|
||||
from common.guardrails import GuardrailRuleSet
|
||||
from integrations.explorer import (
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
|
||||
# requires the 'INVARIANT_API_KEY' environment variable to be set
|
||||
client = Client()
|
||||
|
||||
# trace state (continously expanded on)
|
||||
EXPLORER_DATASET = "mcp-capture"
|
||||
TRACE = []
|
||||
TOOLS = []
|
||||
trace_id = None
|
||||
last_trace_length = 0
|
||||
|
||||
# guardrailing state
|
||||
GUARDRAILS: GuardrailRuleSet | None = None
|
||||
|
||||
# maps JSON RPC IDs to method names
|
||||
id_to_method_mapping = {}
|
||||
|
||||
# set stderr to be log.txt in the ~/.invariant/mcp.log
|
||||
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
|
||||
LOG_OUT = open(
|
||||
os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"),
|
||||
"a",
|
||||
buffering=1,
|
||||
)
|
||||
sys.stderr = LOG_OUT
|
||||
|
||||
|
||||
def print(*args, **kwargs):
|
||||
from builtins import print as builtins_print
|
||||
|
||||
builtins_print(*args, **kwargs, file=LOG_OUT, flush=True)
|
||||
|
||||
|
||||
def append_and_push_trace(message):
|
||||
global trace_id, TRACE, last_trace_length, tool_call_ids
|
||||
|
||||
try:
|
||||
if trace_id is None:
|
||||
TRACE.append(message)
|
||||
response = client.create_request_and_push_trace(
|
||||
messages=[TRACE],
|
||||
dataset="mcp-capture",
|
||||
metadata=[{"source": "mcp", "tools": TOOLS}],
|
||||
)
|
||||
trace_id = response.id[0]
|
||||
last_trace_length = len(TRACE)
|
||||
else:
|
||||
TRACE.append(message)
|
||||
client.create_request_and_append_messages(
|
||||
trace_id=trace_id, messages=TRACE[last_trace_length:]
|
||||
)
|
||||
last_trace_length = len(TRACE)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
def check_blocking_guardrails(message, request):
|
||||
try:
|
||||
guardrails = fetch_guardrails(EXPLORER_DATASET)
|
||||
|
||||
context = RequestContext.create(
|
||||
request_json=request,
|
||||
dataset_name=EXPLORER_DATASET,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=guardrails,
|
||||
)
|
||||
|
||||
with redirect_stdout(LOG_OUT):
|
||||
return asyncio.run(
|
||||
check_guardrails(
|
||||
messages=TRACE + [message],
|
||||
guardrails=guardrails.blocking_guardrails,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
|
||||
def hook_tool_call(request):
|
||||
"""
|
||||
Hook function to intercept tool calls.
|
||||
Modify this function to change behavior for tool calls.
|
||||
Returns the potentially modified request.
|
||||
"""
|
||||
global trace_id, TRACE
|
||||
|
||||
tool_call = {
|
||||
"id": f"call_{request.get('id')}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": request["params"]["name"],
|
||||
"arguments": request["params"]["arguments"],
|
||||
},
|
||||
}
|
||||
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
|
||||
# Check for blocking guardrails
|
||||
result = check_blocking_guardrails(message, request)
|
||||
print("Guardrails check tool call result:", result)
|
||||
|
||||
append_and_push_trace(message)
|
||||
|
||||
return request
|
||||
|
||||
|
||||
def hook_tool_result(result):
|
||||
"""
|
||||
Hook function to intercept tool results.
|
||||
Modify this function to change behavior for tool results.
|
||||
Returns the potentially modified result.
|
||||
"""
|
||||
global TOOLS
|
||||
|
||||
method = id_to_method_mapping.get(result.get("id"))
|
||||
call_id = f"call_{result.get('id')}"
|
||||
|
||||
if method is None:
|
||||
return result
|
||||
elif method == "tools/call":
|
||||
message = {
|
||||
"role": "tool",
|
||||
"content": json.dumps(result.get("result").get("content")),
|
||||
"error": result.get("result").get("error"),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
|
||||
# Check for blocking guardrails
|
||||
guardrailing_result = check_blocking_guardrails(message, result)
|
||||
print("Guardrails check tool output result:", guardrailing_result)
|
||||
|
||||
if len(guardrailing_result["errors"]) > 0:
|
||||
result["result"]["content"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n"
|
||||
+ json.dumps(guardrailing_result["errors"]),
|
||||
}
|
||||
]
|
||||
|
||||
append_and_push_trace(message)
|
||||
|
||||
return result
|
||||
elif method == "tools/list":
|
||||
TOOLS = result.get("result").get("tools")
|
||||
return result
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def forward_stdout(process, buffer_size=1):
|
||||
"""Read from the process stdout, parse JSON chunks, and forward to sys.stdout"""
|
||||
buffer = b""
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
while True:
|
||||
chunk = process.stdout.read(buffer_size)
|
||||
if not chunk:
|
||||
break
|
||||
buffer += chunk
|
||||
|
||||
try:
|
||||
# Try parsing full JSON object from buffer
|
||||
text = buffer.decode("utf-8")
|
||||
obj = json.loads(text)
|
||||
|
||||
obj = hook_tool_result(obj)
|
||||
# clear the buffer
|
||||
buffer = b""
|
||||
|
||||
# Forward the original JSON to stdout
|
||||
json_output = json.dumps(obj).encode("utf-8") + b"\n"
|
||||
sys.stdout.buffer.write(json_output)
|
||||
sys.stdout.buffer.flush()
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# Wait for more data
|
||||
continue
|
||||
|
||||
|
||||
def forward_stderr(process, buffer_size=1):
|
||||
"""Read from the process stderr and write to sys.stderr"""
|
||||
for line in iter(lambda: process.stderr.read(buffer_size), b""):
|
||||
LOG_OUT.buffer.write(line)
|
||||
LOG_OUT.buffer.flush()
|
||||
|
||||
|
||||
def fetch_guardrails(dataset):
|
||||
# Use async fetch_guardrails_from_explorer in a thread
|
||||
return asyncio.run(
|
||||
fetch_guardrails_from_explorer(
|
||||
dataset, "Bearer " + os.getenv("INVARIANT_API_KEY")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <mcp-implementation> [args...]")
|
||||
sys.exit(1)
|
||||
|
||||
# Start the actual MCP implementation
|
||||
cmd = [sys.argv[1]] + sys.argv[2:]
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0, # No buffering
|
||||
)
|
||||
|
||||
# Ensure we have an INVARIANT_API_KEY
|
||||
if "INVARIANT_API_KEY" not in os.environ:
|
||||
print("INVARIANT_API_KEY environment variable is not set.")
|
||||
sys.exit(1)
|
||||
|
||||
# Start threads to forward stdout and stderr
|
||||
stdout_thread = threading.Thread(
|
||||
target=forward_stdout, args=(process,), daemon=True
|
||||
)
|
||||
stderr_thread = threading.Thread(
|
||||
target=forward_stderr, args=(process,), daemon=True
|
||||
)
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
try:
|
||||
current_chunk = b""
|
||||
|
||||
while True:
|
||||
data = sys.stdin.buffer.read(1)
|
||||
current_chunk += data
|
||||
|
||||
if not data:
|
||||
break
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
try:
|
||||
text = current_chunk.decode("utf-8")
|
||||
obj = json.loads(text)
|
||||
# clear the current chunk
|
||||
current_chunk = b""
|
||||
|
||||
if obj.get("method") is not None:
|
||||
id_to_method_mapping[obj.get("id")] = obj.get("method")
|
||||
|
||||
# Check if this is a tool call request
|
||||
if obj.get("method") == "tools/call":
|
||||
# Intercept and potentially modify the request
|
||||
obj = hook_tool_call(obj)
|
||||
# Convert back to bytes
|
||||
data = json.dumps(obj).encode("utf-8")
|
||||
|
||||
# Forward to the process
|
||||
process.stdin.write(data + b"\n")
|
||||
process.stdin.flush()
|
||||
continue
|
||||
else:
|
||||
process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n")
|
||||
process.stdin.flush()
|
||||
continue
|
||||
except Exception:
|
||||
# Not a complete or valid JSON, just pass through
|
||||
pass
|
||||
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
# Clean termination on Ctrl+C
|
||||
process.terminate()
|
||||
|
||||
# Wait for process to terminate
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
||||
|
||||
# Handle signals to ensure clean shutdown
|
||||
def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
main()
|
||||
@@ -53,7 +53,7 @@ def make_cors_response(request: Request, allow_methods: str) -> Response:
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": request.headers.get("origin", "*"),
|
||||
"Access-Control-Allow-Methods": f"{allow_methods}, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type, Invariant-Authorization, Invariant-Guardrails, Invariant-Guardrails-Authorization, Origin",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user