Add logging guardrails based test for MCP gateway.

This commit is contained in:
Hemang
2025-04-29 11:57:25 +05:30
committed by Hemang Sarkar
parent d983b7431b
commit d877e5a1e6
+106
View File
@@ -1,10 +1,12 @@
"""Test MCP gateway via stdio."""
import os
import uuid
import pytest
import requests
from utils import create_dataset, add_guardrail_to_dataset
from resources.mcp.client.main import run as mcp_client_run
@@ -64,3 +66,107 @@ async def test_mcp_stdio_with_gateway(
assert trace["messages"][1]["content"] == [
{"type": "text", "text": "What is your favorite food?\n"}
]
@pytest.mark.asyncio
async def test_mcp_stdio_with_gateway_and_logging_guardrails(
explorer_api_url, invariant_gateway_package_whl_file
):
"""Test MCP gateway via stdio and verify that logging guardrails work"""
project_name = "test-mcp-" + str(uuid.uuid4())
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=project_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "food in ToolOutput" if:\n (tool_output: ToolOutput)\n (chunk: str) in text(tool_output.content)\n "food" in chunk',
action="log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "get_last_message_from_user is called" if:\n (tool_call: ToolCall)\n tool_call is tool:get_last_message_from_user',
action="log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
# Run the MCP client and get the result
result = await mcp_client_run(
invariant_gateway_package_whl_file,
project_name,
server_script_path="resources/mcp/messenger_server/main.py",
push_to_explorer=True,
tool_name="get_last_message_from_user",
tool_args={"username": "Alice"},
)
assert result.isError is False
assert (
result.content[0].type == "text"
and result.content[0].text == "What is your favorite food?\n"
)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{project_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
metadata = trace["extra_metadata"]
assert (
metadata["source"] == "mcp"
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
assert trace["messages"][0]["role"] == "assistant"
assert trace["messages"][0]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
"arguments": {"username": "Alice"},
}
assert trace["messages"][1]["role"] == "tool"
assert trace["messages"][1]["content"] == [
{"type": "text", "text": "What is your favorite food?\n"}
]
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
food_annotation = None
tool_call_annotation = None
assert len(annotations) == 2
for annotation in annotations:
if (
annotation["content"] == "food in ToolOutput"
and annotation["address"] == "messages.1.content.0.text:22-26"
):
food_annotation = annotation
elif (
annotation["content"] == "get_last_message_from_user is called"
and annotation["address"] == "messages.0.tool_calls.0"
):
tool_call_annotation = annotation
assert food_annotation is not None, "Missing 'food in ToolOutput' annotation"
assert (
tool_call_annotation is not None
), "Missing 'get_last_message_from_user is called' annotation"
assert food_annotation["extra_metadata"]["source"] == "guardrails-error"
assert tool_call_annotation["extra_metadata"]["source"] == "guardrails-error"