mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 23:09:44 +02:00
Fix asyncio and threading. Dedupe annotations before pushing. Add README.
This commit is contained in:
+30
-12
@@ -1,34 +1,52 @@
|
||||
"""Script is used to run actions using the Invariant Gateway."""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from gateway.mcp import mcp
|
||||
|
||||
|
||||
def main():
|
||||
"""Entry point for the Invariant Gateway."""
|
||||
# Handle signals to ensure clean shutdown
|
||||
def signal_handler(sig, frame):
|
||||
"""Handle signals for graceful shutdown."""
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def print_help():
|
||||
"""Prints the help message."""
|
||||
actions = {
|
||||
"mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features",
|
||||
"llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features",
|
||||
"help": "Shows this help message",
|
||||
"mcp": "Runs the Invariant Gateway against MCP (Model Context Protocol) servers with guardrailing and push to Explorer features.",
|
||||
"llm": "Runs the Invariant Gateway against LLM providers with guardrailing and push to Explorer features. Not implemented yet.",
|
||||
"help": "Shows this help message.",
|
||||
}
|
||||
|
||||
def _help():
|
||||
"""_prints the help message."""
|
||||
for verb, description in actions.items():
|
||||
print(f"{verb}: {description}")
|
||||
for verb, description in actions.items():
|
||||
print(f"{verb}: {description}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Entry point for the Invariant Gateway."""
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
_help()
|
||||
print_help()
|
||||
sys.exit(1)
|
||||
|
||||
verb = sys.argv[1]
|
||||
if verb == "mcp":
|
||||
return mcp.execute(sys.argv[2:])
|
||||
print("[MCP] Running Invariant Gateway against MCP servers...")
|
||||
return asyncio.run(mcp.execute(sys.argv[2:]))
|
||||
if verb == "llm":
|
||||
print("[gateway/__main__.py] 'llm' action is not implemented yet.")
|
||||
return 1
|
||||
if verb == "help":
|
||||
_help()
|
||||
print_help()
|
||||
return 0
|
||||
print(f"[gateway/__main__.py] Unknown action: {verb}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@@ -93,7 +93,9 @@ class GatewayConfigManager:
|
||||
return local_config
|
||||
|
||||
|
||||
async def GuardrailsInHeader(request: fastapi.Request) -> Optional[GuardrailRuleSet]:
|
||||
async def extract_guardrails_from_header(
|
||||
request: fastapi.Request,
|
||||
) -> Optional[GuardrailRuleSet]:
|
||||
"""
|
||||
Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding
|
||||
GuardrailRuleSet. If no guardrails are provided, returns None.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Utility functions for the Invariant explorer."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
|
||||
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
@@ -16,17 +17,20 @@ DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def create_annotations_from_guardrails_errors(
|
||||
guardrails_errors: List[dict], action: str = "block"
|
||||
guardrails_errors: List[dict],
|
||||
) -> List[AnnotationCreate]:
|
||||
"""Create Explorer annotations from the guardrails errors."""
|
||||
annotations = []
|
||||
|
||||
def _remove_prefixes(ranges: list[str]) -> list[str]:
|
||||
def _pick_most_specific_ranges(ranges: list[str]) -> list[str]:
|
||||
"""
|
||||
Remove prefixes from the list of ranges.
|
||||
Remove redundant prefixes from the list of ranges.
|
||||
|
||||
If the ranges are ['messages.2', 'messages.2.content:25-30', 'messages.2.content']
|
||||
then this returns ['messages.2.content:25-30'].
|
||||
|
||||
This picks the most specific subset of the ranges and removes the rest. If some
|
||||
range is a proper prefix of another range, it is removed.
|
||||
"""
|
||||
ranges = sorted(ranges, key=len)
|
||||
result = []
|
||||
@@ -44,7 +48,7 @@ def create_annotations_from_guardrails_errors(
|
||||
|
||||
for error in guardrails_errors:
|
||||
content = error.get("args")[0]
|
||||
filtered_ranges = _remove_prefixes(list(error.get("ranges", [])))
|
||||
filtered_ranges = _pick_most_specific_ranges(list(error.get("ranges", [])))
|
||||
for r in filtered_ranges:
|
||||
annotations.append(
|
||||
AnnotationCreate(
|
||||
@@ -61,10 +65,40 @@ def create_annotations_from_guardrails_errors(
|
||||
},
|
||||
)
|
||||
)
|
||||
return annotations
|
||||
# Remove duplicates
|
||||
# TODO: Rely on the __eq__ and __hash__ methods of the AnnotationCreate class
|
||||
# to remove duplicates instead of using a custom function.
|
||||
# This is a temporary solution until the Invariant SDK is updated.
|
||||
return remove_duplicates(annotations)
|
||||
|
||||
|
||||
def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]:
|
||||
"""
|
||||
Remove duplicate annotations based on content, address, and extra_metadata.
|
||||
|
||||
Two annotations are considered duplicates if they have the same content,
|
||||
address, and extra_metadata.
|
||||
"""
|
||||
unique_annotations = []
|
||||
seen = set()
|
||||
|
||||
for annotation in annotations:
|
||||
# Convert the entire extra_metadata dict to a JSON string
|
||||
# This creates a hashable representation regardless of nested content
|
||||
metadata_str = json.dumps(annotation.extra_metadata, sort_keys=True)
|
||||
|
||||
# Create a unique identifier using all three fields
|
||||
unique_key = (annotation.content, annotation.address, metadata_str)
|
||||
|
||||
if unique_key not in seen:
|
||||
seen.add(unique_key)
|
||||
unique_annotations.append(annotation)
|
||||
|
||||
return unique_annotations
|
||||
|
||||
|
||||
def get_explorer_api_url() -> str:
|
||||
"""Get the Invariant Explorer API URL from the environment variable."""
|
||||
return os.getenv("INVARIANT_API_URL", DEFAULT_API_URL)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
This is a work in progress implementation for MCP (Model Context Protocol) with the Gateway.
|
||||
|
||||
This repository will be pushed to PyPi and then using it with the MCP config file will be simpler.
|
||||
|
||||
For now if the original MCP config file looks like:
|
||||
|
||||
```
|
||||
{
|
||||
"mcpServers": {
|
||||
"weather": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"--directory",
|
||||
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
|
||||
"run",
|
||||
"weather.py"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You need to:
|
||||
|
||||
1. Checkout the invariant-gatway repo.
|
||||
2. Run `python -m build`. This will generate a .whl file in dist.
|
||||
3. Modify the MCP config like this:
|
||||
|
||||
```
|
||||
{
|
||||
"mcpServers": {
|
||||
"weather": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"--refresh",
|
||||
"--from",
|
||||
"/ABSOLUTE/PATH/TO/INVARIANT_GATEWAY_REPO/dist/invariant_gateway-0.0.1-py3-none-any.whl",
|
||||
"invariant-gateway",
|
||||
"mcp",
|
||||
"--dataset-name",
|
||||
"weather-testing",
|
||||
"--push-explorer",
|
||||
"--exec",
|
||||
"uv",
|
||||
"--directory",
|
||||
"/Users/hemang/Sdk/mcp/weather",
|
||||
"run",
|
||||
"weather.py"
|
||||
],
|
||||
"env": {
|
||||
"INVARIANT_API_KEY": "<Add Invariant API key here>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This moves the original `command` and `args` to the `args` list after the `--exec` flag.
|
||||
|
||||
All args before the `--exec` flag are relevant to the Invariant MCP gateway. These include:
|
||||
|
||||
- `--dataset-name`: With this you can specify the name of the dataset. The guardrails are pulled from this dataset.
|
||||
- `--push-explorer`: With this you can specify if you want to push the annotated traces to the Invariant Explorer.
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Cusym log configuration."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from builtins import print as builtins_print
|
||||
|
||||
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
|
||||
MCP_LOG_FILE = open(
|
||||
os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log"),
|
||||
"a",
|
||||
buffering=1,
|
||||
)
|
||||
sys.stderr = MCP_LOG_FILE
|
||||
|
||||
|
||||
def mcp_log(*args, **kwargs) -> None:
|
||||
"""Custom print function to redirect output to log_out."""
|
||||
builtins_print(*args, **kwargs, file=MCP_LOG_FILE, flush=True)
|
||||
+313
-188
@@ -1,99 +1,181 @@
|
||||
"""Gateway for MCP (Model Context Protocol) integration with Invariant."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import signal
|
||||
|
||||
from builtins import print as builtins_print
|
||||
from contextlib import redirect_stdout
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest
|
||||
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.integrations.explorer import (
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
from gateway.mcp.mcp_context import McpContext
|
||||
from gateway.mcp.task_utils import run_task_in_background, run_task_sync
|
||||
|
||||
MCP_METHOD = "method"
|
||||
UTF_8_ENCODING = "utf-8"
|
||||
MCP_TOOL_CALL = "tools/call"
|
||||
MCP_LIST_TOOLS = "tools/list"
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE = """
|
||||
[Invariant Guardrails] The MCP tool call was blocked for security reasons.
|
||||
Do not attempt to circumvent this block, rather explain to the user based
|
||||
on the following output what went wrong: %s
|
||||
"""
|
||||
|
||||
|
||||
def custom_print(ctx, *args, **kwargs):
|
||||
"""Custom print function to redirect output to log_out."""
|
||||
builtins_print(*args, **kwargs, file=ctx.log_out, flush=True)
|
||||
def write_as_utf8_bytes(data: dict) -> bytes:
|
||||
"""Serializes dict to bytes using UTF-8 encoding."""
|
||||
return json.dumps(data).encode(UTF_8_ENCODING) + b"\n"
|
||||
|
||||
|
||||
def append_and_push_trace(ctx, message):
|
||||
def deduplicate_annotations(ctx: McpContext, new_annotations: list) -> list:
|
||||
"""Deduplicate new_annotations using the annotations in the context."""
|
||||
deduped_annotations = []
|
||||
for annotation in new_annotations:
|
||||
# Check if an annotation with the same content and address exists in ctx.annotations
|
||||
# TODO: Rely on the __eq__ method of the AnnotationCreate class directly via not in
|
||||
# to remove duplicates instead of using a custom logic.
|
||||
# This is a temporary solution until the Invariant SDK is updated.
|
||||
is_duplicate = False
|
||||
for ctx_annotation in ctx.annotations:
|
||||
if (
|
||||
annotation.content == ctx_annotation.content
|
||||
and annotation.address == ctx_annotation.address
|
||||
and annotation.extra_metadata == ctx_annotation.extra_metadata
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
deduped_annotations.append(annotation)
|
||||
|
||||
return deduped_annotations
|
||||
|
||||
|
||||
def check_if_new_errors(ctx: McpContext, guardrails_result: dict) -> bool:
|
||||
"""Checks if there are new errors in the guardrails result."""
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_result.get("errors", [])
|
||||
)
|
||||
for annotation in annotations:
|
||||
if annotation not in ctx.annotations:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def append_and_push_trace(
|
||||
ctx: McpContext, message: dict, guardrails_result: dict
|
||||
) -> None:
|
||||
"""
|
||||
Append a message to the trace if it exists or create a new one
|
||||
and push it to the Invariant Explorer.
|
||||
|
||||
This function runs asynchronously in the background.
|
||||
"""
|
||||
|
||||
annotations = []
|
||||
if guardrails_result and guardrails_result.get("errors", []):
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_result["errors"]
|
||||
)
|
||||
|
||||
if ctx.guardrails.logging_guardrails:
|
||||
logging_guardrails_check_result = get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.LOG
|
||||
)
|
||||
if logging_guardrails_check_result and logging_guardrails_check_result.get(
|
||||
"errors", []
|
||||
):
|
||||
annotations.extend(
|
||||
create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_check_result["errors"]
|
||||
)
|
||||
)
|
||||
deduplicated_annotations = deduplicate_annotations(ctx, annotations)
|
||||
|
||||
try:
|
||||
# If the trace_id is None, create a new trace with the messages.
|
||||
# Otherwise, append the message to the existing trace.
|
||||
client = AsyncClient()
|
||||
if ctx.trace_id is None:
|
||||
ctx.trace.append(message)
|
||||
response = ctx.client.create_request_and_push_trace(
|
||||
messages=[ctx.trace],
|
||||
dataset=ctx.explorer_dataset,
|
||||
metadata=[{"source": "mcp", "tools": ctx.tools}],
|
||||
response = await client.push_trace(
|
||||
PushTracesRequest(
|
||||
messages=[ctx.trace],
|
||||
dataset=ctx.explorer_dataset,
|
||||
metadata=[{"source": "mcp", "tools": ctx.tools}],
|
||||
annotations=[deduplicated_annotations],
|
||||
)
|
||||
)
|
||||
ctx.trace_id = response.id[0]
|
||||
ctx.last_trace_length = len(ctx.trace)
|
||||
ctx.annotations.extend(deduplicated_annotations)
|
||||
else:
|
||||
ctx.trace.append(message)
|
||||
ctx.client.create_request_and_append_messages(
|
||||
trace_id=ctx.trace_id, messages=ctx.trace[ctx.last_trace_length :]
|
||||
response = await client.append_messages(
|
||||
AppendMessagesRequest(
|
||||
trace_id=ctx.trace_id,
|
||||
messages=ctx.trace[ctx.last_trace_length :],
|
||||
annotations=deduplicated_annotations,
|
||||
)
|
||||
)
|
||||
ctx.last_trace_length = len(ctx.trace)
|
||||
ctx.annotations.extend(deduplicated_annotations)
|
||||
except Exception as e:
|
||||
custom_print(ctx, "Error pushing trace:", e)
|
||||
mcp_log("[ERROR] Error pushing trace in append_and_push_trace:", e)
|
||||
|
||||
|
||||
def fetch_guardrails(ctx, dataset):
|
||||
"""Fetch guardrails from the Invariant Explorer."""
|
||||
# Use async fetch_guardrails_from_explorer in a thread
|
||||
return asyncio.run(
|
||||
fetch_guardrails_from_explorer(
|
||||
dataset, "Bearer " + os.getenv("INVARIANT_API_KEY")
|
||||
)
|
||||
def get_guardrails_check_result(
|
||||
ctx: McpContext,
|
||||
message: dict,
|
||||
action: GuardrailAction = GuardrailAction.BLOCK,
|
||||
) -> dict:
|
||||
"""
|
||||
Check against guardrails of type action.
|
||||
Works in both sync and async contexts by always using a dedicated thread.
|
||||
"""
|
||||
# Skip if no guardrails are configured for this action
|
||||
if not (
|
||||
(ctx.guardrails.blocking_guardrails and action == GuardrailAction.BLOCK)
|
||||
or (ctx.guardrails.logging_guardrails and action == GuardrailAction.LOG)
|
||||
):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=ctx.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=ctx.guardrails,
|
||||
)
|
||||
|
||||
guardrails_to_check = (
|
||||
ctx.guardrails.blocking_guardrails
|
||||
if action == GuardrailAction.BLOCK
|
||||
else ctx.guardrails.logging_guardrails
|
||||
)
|
||||
|
||||
return run_task_sync(
|
||||
check_guardrails,
|
||||
messages=ctx.trace + [message],
|
||||
guardrails=guardrails_to_check,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
def check_blocking_guardrails(ctx, message, request):
|
||||
"""Check against blocking guardrails."""
|
||||
try:
|
||||
guardrails = fetch_guardrails(ctx, ctx.explorer_dataset)
|
||||
|
||||
custom_print(ctx, "Here are the guardrails: ", guardrails)
|
||||
|
||||
context = RequestContext.create(
|
||||
request_json=request,
|
||||
dataset_name=ctx.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=guardrails,
|
||||
)
|
||||
|
||||
if guardrails.blocking_guardrails:
|
||||
with redirect_stdout(ctx.log_out):
|
||||
return asyncio.run(
|
||||
check_guardrails(
|
||||
messages=ctx.trace + [message],
|
||||
guardrails=guardrails.blocking_guardrails,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
custom_print(ctx, "Error checking blocking guardrails:", e)
|
||||
|
||||
|
||||
def hook_tool_call(ctx, request):
|
||||
def hook_tool_call(ctx: McpContext, request: dict) -> tuple[dict, bool]:
|
||||
"""
|
||||
Hook function to intercept tool calls.
|
||||
Modify this function to change behavior for tool calls.
|
||||
Returns the potentially modified request.
|
||||
|
||||
If the request is blocked, it returns a tuple with a message explaining the block
|
||||
and a flag indicating the request was blocked.
|
||||
Otherwise it returns the original request and a flag indicating it was not blocked.
|
||||
"""
|
||||
tool_call = {
|
||||
"id": f"call_{request.get('id')}",
|
||||
@@ -106,13 +188,39 @@ def hook_tool_call(ctx, request):
|
||||
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
|
||||
# Check for blocking guardrails
|
||||
result = check_blocking_guardrails(ctx, message, request)
|
||||
append_and_push_trace(ctx, message)
|
||||
return request
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
guardrailing_result = get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrailing_result
|
||||
and guardrailing_result.get("errors", [])
|
||||
and check_if_new_errors(ctx, guardrailing_result)
|
||||
):
|
||||
if ctx.push_explorer:
|
||||
run_task_in_background(
|
||||
append_and_push_trace, ctx, message, guardrailing_result
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrailing_result["errors"],
|
||||
},
|
||||
}, True
|
||||
|
||||
# Add the message to the trace
|
||||
ctx.trace.append(message)
|
||||
return request, False
|
||||
|
||||
|
||||
def hook_tool_result(ctx, result):
|
||||
def hook_tool_result(ctx: McpContext, result: dict) -> dict:
|
||||
"""
|
||||
Hook function to intercept tool results.
|
||||
Modify this function to change behavior for tool results.
|
||||
@@ -123,153 +231,126 @@ def hook_tool_result(ctx, result):
|
||||
|
||||
if method is None:
|
||||
return result
|
||||
elif method == "tools/call":
|
||||
elif method == MCP_TOOL_CALL:
|
||||
message = {
|
||||
"role": "tool",
|
||||
"content": {"result": result.get("result").get("content")},
|
||||
"content": result.get("result").get("content"),
|
||||
"error": result.get("result").get("error"),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
|
||||
# Check for blocking guardrails
|
||||
guardrailing_result = check_blocking_guardrails(ctx, message, result)
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
guardrailing_result = get_guardrails_check_result(
|
||||
ctx, message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
if guardrailing_result and guardrailing_result.get("errors", []):
|
||||
result["result"]["content"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "[Invariant] Your MCP tool call was blocked for security reasons. Do not attempt to circumvent this block, rather explain to the user based on the following output what went wrong: \n"
|
||||
+ json.dumps(guardrailing_result["errors"]),
|
||||
}
|
||||
]
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": result.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrailing_result["errors"],
|
||||
},
|
||||
}
|
||||
|
||||
append_and_push_trace(ctx, message)
|
||||
if ctx.push_explorer:
|
||||
# Run append_and_push_trace in background
|
||||
run_task_in_background(
|
||||
append_and_push_trace, ctx, message, guardrailing_result
|
||||
)
|
||||
return result
|
||||
elif method == "tools/list":
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
ctx.tools = result.get("result").get("tools")
|
||||
return result
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def forward_stdout(process, ctx, buffer_size=1):
|
||||
"""Read from the process stdout, parse JSON chunks, and forward to sys.stdout"""
|
||||
buffer = b""
|
||||
|
||||
while True:
|
||||
chunk = process.stdout.read(buffer_size)
|
||||
if not chunk:
|
||||
break
|
||||
buffer += chunk
|
||||
|
||||
def stream_and_forward_stdout(mcp_process: subprocess.Popen, ctx: McpContext) -> None:
|
||||
"""Read from the mcp_process stdout, apply guardrails and and forward to sys.stdout"""
|
||||
for line in iter(mcp_process.stdout.readline, b""):
|
||||
try:
|
||||
# Try parsing full JSON object from buffer
|
||||
text = buffer.decode("utf-8")
|
||||
obj = json.loads(text)
|
||||
# Process complete JSON lines
|
||||
line_str = line.decode(UTF_8_ENCODING).strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
obj = hook_tool_result(ctx, obj)
|
||||
# clear the buffer
|
||||
buffer = b""
|
||||
parsed_json = json.loads(line_str)
|
||||
processed_json = hook_tool_result(ctx, parsed_json)
|
||||
|
||||
# Forward the original JSON to stdout
|
||||
json_output = json.dumps(obj).encode("utf-8") + b"\n"
|
||||
sys.stdout.buffer.write(json_output)
|
||||
# Write and flush immediately
|
||||
sys.stdout.buffer.write(write_as_utf8_bytes(processed_json))
|
||||
sys.stdout.buffer.flush()
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# Wait for more data
|
||||
continue
|
||||
except json.JSONDecodeError as je:
|
||||
mcp_log(f"[ERROR] JSON decode error in stdout processing: {str(je)}")
|
||||
mcp_log(f"[ERROR] Problematic line: {line[:200]}...")
|
||||
|
||||
except Exception as e:
|
||||
mcp_log(f"[ERROR] Error in stream_and_forward_stdout: {str(e)}")
|
||||
if line:
|
||||
mcp_log(f"[ERROR] Problematic line causing error: {line[:200]}...")
|
||||
|
||||
|
||||
def forward_stderr(process, ctx, buffer_size=1):
|
||||
"""Read from the process stderr and write to sys.stderr"""
|
||||
for line in iter(lambda: process.stderr.read(buffer_size), b""):
|
||||
ctx.log_out.buffer.write(line)
|
||||
ctx.log_out.buffer.flush()
|
||||
def stream_and_forward_stderr(
|
||||
mcp_process: subprocess.Popen, ctx: McpContext, read_chunk_size: int = 1
|
||||
) -> None:
|
||||
"""Read from the mcp_process stderr and write to sys.stderr"""
|
||||
for line in iter(lambda: mcp_process.stderr.read(read_chunk_size), b""):
|
||||
MCP_LOG_FILE.buffer.write(line)
|
||||
MCP_LOG_FILE.buffer.flush()
|
||||
|
||||
|
||||
def execute(args=None):
|
||||
"""Main function to execute the MCP gateway."""
|
||||
if "INVARIANT_API_KEY" not in os.environ:
|
||||
print("[ERROR] INVARIANT_API_KEY environment variable is not set.")
|
||||
sys.exit(1)
|
||||
|
||||
# Split args at the "--exec" boundary
|
||||
if args and "--exec" in args:
|
||||
exec_index = args.index("--exec")
|
||||
pre_exec_args = args[:exec_index]
|
||||
post_exec_args = args[exec_index + 1 :]
|
||||
else:
|
||||
pre_exec_args = args or []
|
||||
post_exec_args = []
|
||||
|
||||
if not post_exec_args:
|
||||
print("[ERROR] No command provided after --exec.")
|
||||
sys.exit(1)
|
||||
|
||||
# Parse pre-exec args using argparse
|
||||
parser = argparse.ArgumentParser(description="MCP Gateway")
|
||||
parser.add_argument("--directory", help="Working directory")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
|
||||
config = parser.parse_args(pre_exec_args)
|
||||
# Initialize the singleton context using config
|
||||
ctx = McpContext()
|
||||
|
||||
# Can now use post_exec_args as your cmd
|
||||
cmd = post_exec_args
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0, # No buffering
|
||||
)
|
||||
|
||||
# Start threads to forward stdout and stderr
|
||||
stdout_thread = threading.Thread(
|
||||
target=forward_stdout, args=(process, ctx), daemon=True
|
||||
)
|
||||
stderr_thread = threading.Thread(
|
||||
target=forward_stderr, args=(process, ctx), daemon=True
|
||||
)
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
def run_stdio_input_loop(ctx: McpContext, mcp_process: subprocess.Popen) -> None:
|
||||
"""Handle standard input, intercept call and forward requests to mcp_process stdin."""
|
||||
try:
|
||||
current_chunk = b""
|
||||
|
||||
while True:
|
||||
data = sys.stdin.buffer.read(1)
|
||||
current_chunk += data
|
||||
buffer_input = sys.stdin.buffer.read(1)
|
||||
current_chunk += buffer_input
|
||||
|
||||
if not data:
|
||||
if not buffer_input:
|
||||
break
|
||||
|
||||
# Try to decode and parse as JSON to check for tool calls
|
||||
try:
|
||||
text = current_chunk.decode("utf-8")
|
||||
obj = json.loads(text)
|
||||
# clear the current chunk
|
||||
text = current_chunk.decode(UTF_8_ENCODING)
|
||||
parsed_json = json.loads(text)
|
||||
# clear the current chunk after successful parse
|
||||
current_chunk = b""
|
||||
# Refresh guardrails
|
||||
run_task_sync(ctx.load_guardrails)
|
||||
|
||||
if obj.get("method") is not None:
|
||||
ctx.id_to_method_mapping[obj.get("id")] = obj.get("method")
|
||||
if parsed_json.get(MCP_METHOD) is not None:
|
||||
ctx.id_to_method_mapping[parsed_json.get("id")] = parsed_json.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
|
||||
# Check if this is a tool call request
|
||||
if obj.get("method") == "tools/call":
|
||||
# Intercept and potentially modify the request
|
||||
obj = hook_tool_call(ctx, obj)
|
||||
# Convert back to bytes
|
||||
data = json.dumps(obj).encode("utf-8")
|
||||
|
||||
# Forward to the process
|
||||
process.stdin.write(data + b"\n")
|
||||
process.stdin.flush()
|
||||
if parsed_json.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Intercept and potentially block modify the request
|
||||
hook_tool_call_result, is_blocked = hook_tool_call(ctx, parsed_json)
|
||||
if not is_blocked:
|
||||
# If blocked, hook_tool_call_result contains the original request.
|
||||
# Forward the request to the MCP process.
|
||||
# It will handle the request and return a response.
|
||||
mcp_process.stdin.write(
|
||||
write_as_utf8_bytes(hook_tool_call_result)
|
||||
)
|
||||
mcp_process.stdin.flush()
|
||||
else:
|
||||
# If blocked, hook_tool_call_result contains the block message.
|
||||
# Forward the block message result back to the caller.
|
||||
# The original request is not passed to the MCP process.
|
||||
sys.stdout.buffer.write(
|
||||
write_as_utf8_bytes(hook_tool_call_result)
|
||||
)
|
||||
sys.stdout.buffer.flush()
|
||||
continue
|
||||
else:
|
||||
process.stdin.write(json.dumps(obj).encode("utf-8") + b"\n")
|
||||
process.stdin.flush()
|
||||
mcp_process.stdin.write(write_as_utf8_bytes(parsed_json))
|
||||
mcp_process.stdin.flush()
|
||||
continue
|
||||
except Exception:
|
||||
# Not a complete or valid JSON, just pass through
|
||||
@@ -278,25 +359,69 @@ def execute(args=None):
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
process.terminate()
|
||||
mcp_process.terminate()
|
||||
|
||||
|
||||
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Splits CLI arguments into two parts:
|
||||
1. Arguments intended for the MCP gateway (everything before `--exec`)
|
||||
2. Arguments for the underlying MCP server (everything after `--exec`)
|
||||
|
||||
Parameters:
|
||||
args (list[str]): The list of CLI arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[list[str], list[str]]: A tuple containing (mcp_gateway_args, mcp_server_command_args)
|
||||
"""
|
||||
if not args:
|
||||
mcp_log("[ERROR] No arguments provided.")
|
||||
sys.exit(1)
|
||||
|
||||
# Wait for process to terminate
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait()
|
||||
exec_index = args.index("--exec")
|
||||
except ValueError:
|
||||
mcp_log("[ERROR] '--exec' flag not found in arguments.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_gateway_args = args[:exec_index]
|
||||
mcp_server_command_args = args[exec_index + 1 :]
|
||||
|
||||
if not mcp_server_command_args:
|
||||
mcp_log("[ERROR] No arguments provided after '--exec'.")
|
||||
sys.exit(1)
|
||||
|
||||
return mcp_gateway_args, mcp_server_command_args
|
||||
|
||||
|
||||
# Handle signals to ensure clean shutdown
|
||||
def signal_handler(sig, frame):
|
||||
"""Handle signals for graceful shutdown."""
|
||||
ctx = McpContext()
|
||||
custom_print(ctx, f"Received signal {sig}, shutting down...")
|
||||
sys.exit(0)
|
||||
async def execute(args: list[str] = None):
|
||||
"""Main function to execute the MCP gateway."""
|
||||
if "INVARIANT_API_KEY" not in os.environ:
|
||||
mcp_log("[ERROR] INVARIANT_API_KEY environment variable is not set.")
|
||||
sys.exit(1)
|
||||
|
||||
mcp_gateway_args, mcp_server_command_args = split_args(args)
|
||||
ctx = McpContext(mcp_gateway_args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
execute(sys.argv)
|
||||
mcp_process = subprocess.Popen(
|
||||
mcp_server_command_args,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
bufsize=0,
|
||||
)
|
||||
|
||||
# Start threads to forward stdout and stderr
|
||||
threading.Thread(
|
||||
target=stream_and_forward_stdout,
|
||||
args=(mcp_process, ctx),
|
||||
daemon=True,
|
||||
).start()
|
||||
threading.Thread(
|
||||
target=stream_and_forward_stderr,
|
||||
args=(mcp_process, ctx),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
# Handle forwarding stdin and intercept tool calls
|
||||
run_stdio_input_loop(ctx, mcp_process)
|
||||
|
||||
+41
-28
@@ -1,9 +1,13 @@
|
||||
"""Context manager for MCP (Model Context Protocol) gateway."""
|
||||
|
||||
import atexit
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from invariant_sdk.client import Client
|
||||
import random
|
||||
|
||||
from gateway.integrations.explorer import (
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailRuleSet
|
||||
|
||||
|
||||
class McpContext:
|
||||
@@ -11,44 +15,53 @@ class McpContext:
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
"""Control instance creation to ensure only one instance exists."""
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(McpContext, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the singleton instance with default values (only once)."""
|
||||
# Define _initialized attribute explicitly at the beginning to avoid warnings
|
||||
# This is redundant but prevents warnings about accessing before definition
|
||||
def __init__(self, cli_args: list):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._initialized = False
|
||||
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
def setup_logging(self):
|
||||
"""Set up logging to a file in the user's home directory.
|
||||
|
||||
Uses proper resource management to ensure the file is closed on program exit.
|
||||
"""
|
||||
os.makedirs(
|
||||
os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True
|
||||
)
|
||||
log_path = os.path.join(os.path.expanduser("~"), ".invariant", "mcp.log")
|
||||
self.log_out = open(log_path, "a", buffering=1, encoding="utf-8")
|
||||
atexit.register(self.log_out.close)
|
||||
sys.stderr = self.log_out
|
||||
|
||||
self.client = Client()
|
||||
self.explorer_dataset = "mcp-capture"
|
||||
config = self._parse_cli_args(cli_args)
|
||||
self.explorer_dataset = config.dataset_name
|
||||
self.push_explorer = config.push_explorer
|
||||
self.trace = []
|
||||
self.tools = []
|
||||
self.guardrails = GuardrailRuleSet(
|
||||
blocking_guardrails=[], logging_guardrails=[]
|
||||
)
|
||||
# We send the same trace messages for guardrails analysis multiple times.
|
||||
# We need to deduplicate them before sending to the explorer.
|
||||
self.annotations = []
|
||||
self.trace_id = None
|
||||
self.last_trace_length = 0
|
||||
self.guardrails = None
|
||||
self.id_to_method_mapping = {}
|
||||
setup_logging(self)
|
||||
# Mark as initialized
|
||||
self._initialized = True
|
||||
|
||||
def _parse_cli_args(self, cli_args: list) -> argparse.Namespace:
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="MCP Gateway")
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
help="Name of the dataset where we want to push the MCP traces",
|
||||
type=str,
|
||||
default=f"mcp-capture-{random.randint(1, 100)}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-explorer",
|
||||
help="Enable pushing traces to Invariant Explorer",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
return parser.parse_args(cli_args)
|
||||
|
||||
async def load_guardrails(self):
|
||||
"""Run async setup logic (e.g. fetching guardrails)."""
|
||||
self.guardrails = await fetch_guardrails_from_explorer(
|
||||
self.explorer_dataset, "Bearer " + os.getenv("INVARIANT_API_KEY")
|
||||
)
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Task utilities for running async functions"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any
|
||||
|
||||
from gateway.mcp.log import MCP_LOG_FILE, mcp_log
|
||||
|
||||
|
||||
def run_task_in_background(async_func, *args, **kwargs):
|
||||
"""
|
||||
Runs an async function in a background thread with its own event loop.
|
||||
This function does NOT block the calling thread as it immediately returns
|
||||
after starting the background thread.
|
||||
|
||||
Args:
|
||||
async_func: The async function to run
|
||||
*args: Positional arguments to pass to the async function
|
||||
**kwargs: Keyword arguments to pass to the async function
|
||||
"""
|
||||
|
||||
def thread_target():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(async_func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
mcp_log(
|
||||
f"[ERROR] Error in async thread while running run_task_in_background: {e}"
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Create and start a daemon thread
|
||||
thread = threading.Thread(target=thread_target, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def run_task_sync(async_func, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Runs an asynchronous function synchronously in a separate
|
||||
thread with its own event loop. This function blocks the calling
|
||||
thread until completion or timeout (10 seconds).
|
||||
|
||||
Args:
|
||||
async_func: The async function to run
|
||||
*args: Positional arguments to pass to the async function
|
||||
**kwargs: Keyword arguments to pass to the async function
|
||||
|
||||
Returns:
|
||||
Any: The return value of the async function
|
||||
"""
|
||||
|
||||
def run_in_new_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
async_func(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
with redirect_stdout(MCP_LOG_FILE):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=10.0)
|
||||
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
|
||||
from gateway.common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
extract_guardrails_from_header,
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
@@ -34,7 +34,6 @@ from gateway.integrations.guardrails import (
|
||||
InstrumentedStreamingResponse,
|
||||
Replacement,
|
||||
check_guardrails,
|
||||
preload_guardrails,
|
||||
)
|
||||
|
||||
gateway = APIRouter()
|
||||
@@ -70,7 +69,7 @@ async def anthropic_v1_messages_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
|
||||
):
|
||||
"""Proxy calls to the Anthropic APIs"""
|
||||
headers = {
|
||||
@@ -171,7 +170,7 @@ async def push_to_explorer(
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_execution_result.get("errors", []), action="block"
|
||||
guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
@@ -181,7 +180,7 @@ async def push_to_explorer(
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
|
||||
from gateway.common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
extract_guardrails_from_header,
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
@@ -31,7 +31,6 @@ from gateway.integrations.guardrails import (
|
||||
InstrumentedResponse,
|
||||
InstrumentedStreamingResponse,
|
||||
Replacement,
|
||||
preload_guardrails,
|
||||
check_guardrails,
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ async def gemini_generate_content_gateway(
|
||||
None, title="Response Format", description="Set to 'sse' for streaming"
|
||||
),
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
|
||||
) -> Response:
|
||||
"""Proxy calls to the Gemini GenerateContent API"""
|
||||
if endpoint not in ["generateContent", "streamGenerateContent"]:
|
||||
@@ -413,7 +412,7 @@ async def push_to_explorer(
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_execution_result.get("errors", []), action="block"
|
||||
guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
@@ -423,7 +422,7 @@ async def push_to_explorer(
|
||||
response_json=response_json,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
@@ -12,7 +12,7 @@ from gateway.common.authorization import extract_authorization_from_headers
|
||||
from gateway.common.config_manager import (
|
||||
GatewayConfig,
|
||||
GatewayConfigManager,
|
||||
GuardrailsInHeader,
|
||||
extract_guardrails_from_header,
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
@@ -30,7 +30,6 @@ from gateway.integrations.guardrails import (
|
||||
InstrumentedResponse,
|
||||
InstrumentedStreamingResponse,
|
||||
check_guardrails,
|
||||
preload_guardrails,
|
||||
)
|
||||
|
||||
gateway = APIRouter()
|
||||
@@ -113,7 +112,7 @@ async def openai_chat_completions_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(GuardrailsInHeader),
|
||||
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
|
||||
) -> Response:
|
||||
"""Proxy calls to the OpenAI APIs"""
|
||||
headers = {
|
||||
@@ -491,9 +490,7 @@ async def push_to_explorer(
|
||||
# or if the guardrails check returned errors.
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
guardrails_errors = guardrails_execution_result.get("errors", [])
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_errors, action="block"
|
||||
)
|
||||
annotations = create_annotations_from_guardrails_errors(guardrails_errors)
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
@@ -501,7 +498,7 @@ async def push_to_explorer(
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
logging_guardrails_execution_result.get("errors", [])
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
Reference in New Issue
Block a user