mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Move hook_tool_call and hook_tool_call_response to mcp_utils.py so that it can be used by both SSE and Streamable implementations.
This commit is contained in:
@@ -1,10 +1,24 @@
|
||||
"""MCP utility functions."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from typing import Tuple
|
||||
|
||||
from gateway.common.constants import MCP_SERVER_BASE_URL_HEADER
|
||||
from fastapi import Request, HTTPException
|
||||
from gateway.common.constants import (
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
MCP_SERVER_BASE_URL_HEADER,
|
||||
MCP_PARAMS,
|
||||
MCP_RESULT,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
|
||||
|
||||
def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
@@ -49,3 +63,150 @@ def get_mcp_server_base_url(request: Request) -> str:
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
return _convert_localhost_to_docker_host(mcp_server_base_url).rstrip("/")
|
||||
|
||||
|
||||
async def hook_tool_call(
|
||||
session_id: str, session_store: McpSessionsManager, request_body: dict
|
||||
) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Hook to process the request JSON before sending it to the MCP server.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
request_body (dict): The request JSON to be processed.
|
||||
"""
|
||||
tool_call = {
|
||||
"id": f"call_{request_body.get('id')}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": request_body.get(MCP_PARAMS).get("name"),
|
||||
"arguments": request_body.get(MCP_PARAMS).get("arguments"),
|
||||
},
|
||||
}
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
):
|
||||
# Add the trace to the explorer
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
guardrails_result=guardrails_result,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}, True
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
return request_body, False
|
||||
|
||||
|
||||
def check_if_new_errors(
|
||||
session_id: str, session_store: McpSessionsManager, guardrails_result: dict
|
||||
) -> bool:
|
||||
"""Checks if there are new errors in the guardrails result."""
|
||||
session = session_store.get_session(session_id)
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_result.get("errors", [])
|
||||
)
|
||||
for annotation in annotations:
|
||||
if annotation not in session.annotations:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def hook_tool_call_response(
|
||||
session_id: str,
|
||||
session_store: McpSessionsManager,
|
||||
response_json: dict,
|
||||
is_tools_list=False,
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
response_json (dict): The response JSON to be processed.
|
||||
Returns:
|
||||
dict: The response JSON is returned if no guardrail is violated
|
||||
else an error dict is returned.
|
||||
"""
|
||||
blocked = False
|
||||
message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": f"call_{response_json.get('id')}",
|
||||
"content": response_json.get(MCP_RESULT).get("content"),
|
||||
"error": response_json.get(MCP_RESULT).get("error"),
|
||||
}
|
||||
result = response_json
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and check_if_new_errors(session_id, session_store, guardrails_result)
|
||||
):
|
||||
blocked = True
|
||||
# If the request is blocked, return a message indicating the block reason
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
else:
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "blocked_" + tool["name"],
|
||||
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
|
||||
% format_errors_in_response(guardrails_result["errors"]),
|
||||
"inputSchema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "invariant_mcp_server_blockedArguments",
|
||||
"type": "object",
|
||||
},
|
||||
"annotations": {
|
||||
"title": "This tool was blocked by security guardrails.",
|
||||
},
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result, blocked
|
||||
|
||||
@@ -12,8 +12,6 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
MCP_METHOD,
|
||||
MCP_TOOL_CALL,
|
||||
MCP_LIST_TOOLS,
|
||||
@@ -23,14 +21,15 @@ from gateway.common.constants import (
|
||||
MCP_CLIENT_INFO,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import get_mcp_server_base_url
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.common.mcp_utils import (
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
hook_tool_call_response,
|
||||
)
|
||||
|
||||
MCP_SERVER_POST_HEADERS = {
|
||||
"connection",
|
||||
@@ -71,35 +70,38 @@ async def mcp_post_sse_gateway(
|
||||
get_mcp_server_base_url(request) + "/messages/?" + session_id
|
||||
)
|
||||
request_body_bytes = await request.body()
|
||||
request_json = json.loads(request_body_bytes)
|
||||
request_body = json.loads(request_body_bytes)
|
||||
session = session_store.get_session(session_id)
|
||||
if request_json.get(MCP_METHOD) and request_json.get("id"):
|
||||
session.id_to_method_mapping[request_json.get("id")] = request_json.get(
|
||||
if request_body.get(MCP_METHOD) and request_body.get("id"):
|
||||
session.id_to_method_mapping[request_body.get("id")] = request_body.get(
|
||||
MCP_METHOD
|
||||
)
|
||||
if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
|
||||
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
|
||||
MCP_CLIENT_INFO
|
||||
):
|
||||
session.metadata["mcp_client"] = (
|
||||
request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
if request_json.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
# Intercept and potentially block the request
|
||||
hook_tool_call_result, is_blocked = await _hook_tool_call(
|
||||
session_id=session_id, request_json=request_json
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
request_body=request_body,
|
||||
)
|
||||
if is_blocked:
|
||||
# Add the error message to the session.
|
||||
# The error message is sent back to the client using the SSE stream.
|
||||
await session.add_pending_error_message(hook_tool_call_result)
|
||||
return Response(content="Accepted", status_code=202)
|
||||
elif request_json.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
# Intercept and potentially block the request
|
||||
hook_tool_call_result, is_blocked = await _hook_tool_call(
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
request_json={
|
||||
"id": request_json.get("id"),
|
||||
session_store=session_store,
|
||||
request_body={
|
||||
"id": request_body.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
"params": {"name": MCP_LIST_TOOLS, "arguments": {}},
|
||||
},
|
||||
@@ -119,7 +121,7 @@ async def mcp_post_sse_gateway(
|
||||
for k, v in request.headers.items()
|
||||
if k.lower() in MCP_SERVER_POST_HEADERS
|
||||
},
|
||||
json=request_json,
|
||||
json=request_body,
|
||||
params=query_params,
|
||||
)
|
||||
return Response(
|
||||
@@ -293,135 +295,6 @@ async def mcp_get_sse_gateway(
|
||||
)
|
||||
|
||||
|
||||
async def _hook_tool_call(session_id: str, request_json: dict) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Hook to process the request JSON before sending it to the MCP server.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
request_json (dict): The request JSON to be processed.
|
||||
"""
|
||||
tool_call = {
|
||||
"id": f"call_{request_json.get('id')}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": request_json.get(MCP_PARAMS).get("name"),
|
||||
"arguments": request_json.get(MCP_PARAMS).get("arguments"),
|
||||
},
|
||||
}
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
# Add the trace to the explorer
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
guardrails_result=guardrails_result,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}, True
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
return request_json, False
|
||||
|
||||
|
||||
async def _hook_tool_call_response(
|
||||
session_id: str, response_json: dict, is_tools_list=False
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
response_json (dict): The response JSON to be processed.
|
||||
Returns:
|
||||
dict: The response JSON is returned if no guardrail is violated
|
||||
else an error dict is returned.
|
||||
"""
|
||||
blocked = False
|
||||
message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": f"call_{response_json.get('id')}",
|
||||
"content": response_json.get(MCP_RESULT).get("content"),
|
||||
"error": response_json.get(MCP_RESULT).get("error"),
|
||||
}
|
||||
result = response_json
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
blocked = True
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
else:
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "blocked_" + tool["name"],
|
||||
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
|
||||
% format_errors_in_response(guardrails_result["errors"]),
|
||||
# no parameters
|
||||
"inputSchema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "invariant_mcp_server_blockedArguments",
|
||||
"type": "object",
|
||||
},
|
||||
"annotations": {
|
||||
"title": "This tool was blocked by security guardrails.",
|
||||
},
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result, blocked
|
||||
|
||||
|
||||
async def _handle_endpoint_event(
|
||||
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
|
||||
) -> Tuple[bytes, str]:
|
||||
@@ -476,8 +349,9 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
|
||||
method = session.id_to_method_mapping.get(response_json.get("id"))
|
||||
if method == MCP_TOOL_CALL:
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json=response_json,
|
||||
)
|
||||
# Update the event bytes with hook_tool_call_response.
|
||||
@@ -485,8 +359,8 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
UTF_8
|
||||
event_bytes = (
|
||||
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
@@ -494,8 +368,9 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json={
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
@@ -512,8 +387,8 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
# If guardrail is violated, it contains the error message.
|
||||
# pylint: disable=line-too-long
|
||||
if blocked:
|
||||
event_bytes = f"event: {sse.event}\ndata: {json.dumps(hook_tool_call_response)}\n\n".encode(
|
||||
UTF_8
|
||||
event_bytes = (
|
||||
f"event: {sse.event}\ndata: {json.dumps(result)}\n\n".encode(UTF_8)
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -529,18 +404,6 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
return event_bytes
|
||||
|
||||
|
||||
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
|
||||
"""Checks if there are new errors in the guardrails result."""
|
||||
session = session_store.get_session(session_id)
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_result.get("errors", [])
|
||||
)
|
||||
for annotation in annotations:
|
||||
if annotation not in session.annotations:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _check_for_pending_error_messages(
|
||||
session_id: str, pending_error_messages_queue: asyncio.Queue
|
||||
):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
@@ -13,8 +12,6 @@ from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE,
|
||||
INVARIANT_SESSION_ID_PREFIX,
|
||||
MCP_CLIENT_INFO,
|
||||
MCP_LIST_TOOLS,
|
||||
@@ -25,14 +22,15 @@ from gateway.common.constants import (
|
||||
MCP_TOOL_CALL,
|
||||
UTF_8,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import get_mcp_server_base_url
|
||||
from gateway.integrations.explorer import create_annotations_from_guardrails_errors
|
||||
from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.common.mcp_utils import (
|
||||
get_mcp_server_base_url,
|
||||
hook_tool_call,
|
||||
hook_tool_call_response,
|
||||
)
|
||||
|
||||
gateway = APIRouter()
|
||||
session_store = McpSessionsManager()
|
||||
@@ -448,146 +446,6 @@ async def _handle_mcp_streaming_response(
|
||||
)
|
||||
|
||||
|
||||
def _check_if_new_errors(session_id: str, guardrails_result: dict) -> bool:
|
||||
"""Checks if there are new errors in the guardrails result."""
|
||||
session = session_store.get_session(session_id)
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_result.get("errors", [])
|
||||
)
|
||||
for annotation in annotations:
|
||||
if annotation not in session.annotations:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _hook_tool_call(session_id: str, request_body: dict) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Hook to process the request JSON before sending it to the MCP server.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
request_body (dict): The request JSON to be processed.
|
||||
"""
|
||||
tool_call = {
|
||||
"id": f"call_{request_body.get('id')}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": request_body.get(MCP_PARAMS).get("name"),
|
||||
"arguments": request_body.get(MCP_PARAMS).get("arguments"),
|
||||
},
|
||||
}
|
||||
message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
|
||||
# Check for blocking guardrails - this blocks until completion
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
# If the request is blocked, return a message indicating the block reason.
|
||||
# If there are new errors, run append_and_push_trace in background.
|
||||
# If there are no new errors, just return the original request.
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
# Add the trace to the explorer
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(
|
||||
session_id=session_id,
|
||||
message=message,
|
||||
guardrails_result=guardrails_result,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}, True
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
return request_body, False
|
||||
|
||||
|
||||
async def _hook_tool_call_response(
|
||||
session_id: str, response_json: dict, is_tools_list=False
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
Args:
|
||||
session_id (str): The session ID associated with the request.
|
||||
response_json (dict): The response JSON to be processed.
|
||||
Returns:
|
||||
dict: The response JSON is returned if no guardrail is violated
|
||||
else an error dict is returned.
|
||||
"""
|
||||
blocked = False
|
||||
message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": f"call_{response_json.get('id')}",
|
||||
"content": response_json.get(MCP_RESULT).get("content"),
|
||||
"error": response_json.get(MCP_RESULT).get("error"),
|
||||
}
|
||||
result = response_json
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
message, action=GuardrailAction.BLOCK
|
||||
)
|
||||
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and _check_if_new_errors(session_id, guardrails_result)
|
||||
):
|
||||
blocked = True
|
||||
# If the request is blocked, return a message indicating the block reason
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
else:
|
||||
# special error response for tools/list tool call
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "blocked_" + tool["name"],
|
||||
"description": INVARIANT_GUARDRAILS_BLOCKED_TOOLS_MESSAGE
|
||||
% format_errors_in_response(guardrails_result["errors"]),
|
||||
"inputSchema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "invariant_mcp_server_blockedArguments",
|
||||
"type": "object",
|
||||
},
|
||||
"annotations": {
|
||||
"title": "This tool was blocked by security guardrails.",
|
||||
},
|
||||
}
|
||||
for tool in response_json["result"]["tools"]
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Push trace to the explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
session_store.add_message_to_session(session_id, message, guardrails_result)
|
||||
)
|
||||
return result, blocked
|
||||
|
||||
|
||||
async def _intercept_request(session_id: str, request_body: dict) -> Response | None:
|
||||
"""
|
||||
Intercept the request and check for guardrails.
|
||||
@@ -595,8 +453,10 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
|
||||
If the request is blocked, it returns a message indicating the block reason.
|
||||
"""
|
||||
if request_body.get(MCP_METHOD) == MCP_TOOL_CALL:
|
||||
hook_tool_call_result, is_blocked = await _hook_tool_call(
|
||||
session_id=session_id, request_body=request_body
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
request_body=request_body,
|
||||
)
|
||||
if is_blocked:
|
||||
return Response(
|
||||
@@ -605,8 +465,9 @@ async def _intercept_request(session_id: str, request_body: dict) -> Response |
|
||||
media_type="application/json",
|
||||
)
|
||||
elif request_body.get(MCP_METHOD) == MCP_LIST_TOOLS:
|
||||
hook_tool_call_result, is_blocked = await _hook_tool_call(
|
||||
hook_tool_call_result, is_blocked = await hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
request_body={
|
||||
"id": request_body.get("id"),
|
||||
"method": MCP_LIST_TOOLS,
|
||||
@@ -636,10 +497,12 @@ async def _intercept_response(
|
||||
method = session.id_to_method_mapping.get(response_json.get("id"))
|
||||
# Intercept and potentially block tool call response
|
||||
if method == MCP_TOOL_CALL:
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
session_id=session_id, response_json=response_json
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json=response_json,
|
||||
)
|
||||
return hook_tool_call_response, blocked
|
||||
return result, blocked
|
||||
# Intercept and potentially block list tool call response
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
@@ -647,8 +510,9 @@ async def _intercept_response(
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
hook_tool_call_response, blocked = await _hook_tool_call_response(
|
||||
result, blocked = await hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_json={
|
||||
"id": response_json.get("id"),
|
||||
"result": {
|
||||
@@ -658,5 +522,5 @@ async def _intercept_response(
|
||||
},
|
||||
is_tools_list=True,
|
||||
)
|
||||
return hook_tool_call_response, blocked
|
||||
return result, blocked
|
||||
return response_json, False
|
||||
|
||||
Reference in New Issue
Block a user