Update metadata in MCP SSE similar to what we do in MCP stdio.

This commit is contained in:
Hemang
2025-05-21 16:15:14 +02:00
committed by Hemang Sarkar
parent 177d247a83
commit 03817b005c
3 changed files with 35 additions and 13 deletions

View File

@@ -2,8 +2,10 @@
import asyncio
import contextlib
import getpass
import os
import random
import socket
from typing import Any, Dict, List, Optional
@@ -24,6 +26,14 @@ from gateway.integrations.guardrails import check_guardrails
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
def user_and_host() -> str:
"""Get the current user and hostname."""
username = getpass.getuser()
hostname = socket.gethostname()
return f"{username}@{hostname}"
class McpSession(BaseModel):
"""
Represents a single MCP session.
@@ -81,6 +91,14 @@ class McpSession(BaseModel):
async with self._lock:
yield
def session_metadata(self) -> dict:
"""Generate metadata for the current session."""
return {
"session_id": self.session_id,
"system_user": user_and_host(),
**(self.metadata or {}),
}
async def get_guardrails_check_result(
self,
message: dict,
@@ -102,6 +120,10 @@ class McpSession(BaseModel):
dataset_name=self.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
guardrails=self.guardrails,
guardrails_parameters={
"metadata": self.session_metadata(),
"action": action,
},
)
guardrails_to_check = (
@@ -170,12 +192,10 @@ class McpSession(BaseModel):
# If no trace exists, create a new one
if not self.trace_id:
# pylint: disable=no-member
metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])}
if self.metadata.get("mcp_client_name"):
metadata["mcp_client"] = self.metadata.get("mcp_client_name")
if self.metadata.get("mcp_server_name"):
metadata["mcp_server"] = self.metadata.get("mcp_server_name")
# default metadata
metadata = {"source": "mcp"}
# include MCP session metadata
metadata.update(self.session_metadata())
response = await client.push_trace(
PushTracesRequest(

View File

@@ -89,7 +89,7 @@ async def mcp_post_gateway(
if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.metadata["mcp_client_name"] = (
session.metadata["mcp_client"] = (
request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
@@ -446,10 +446,6 @@ def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
Returns:
str: Modified server address with localhost references changed to host.docker.internal
"""
# check if we are running in a docker container
if not os.environ.get("DOCKER_ENV"):
return mcp_server_base_url
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
# Replace localhost or 127.0.0.1 with host.docker.internal
modified_address = re.sub(
@@ -510,7 +506,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.metadata["mcp_server_name"] = (
session.metadata["mcp_server"] = (
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
)

View File

@@ -174,6 +174,8 @@ async def test_mcp_with_gateway_and_logging_guardrails(
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
assert "session_id" in metadata
assert "system_user" in metadata
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
@@ -287,6 +289,8 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
assert "session_id" in metadata
assert "system_user" in metadata
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
@@ -388,6 +392,8 @@ async def test_mcp_with_gateway_hybrid_guardrails(
and metadata["mcp_client"] == "mcp"
and metadata["mcp_server"] == "messenger_server"
)
assert "session_id" in metadata
assert "system_user" in metadata
assert trace["messages"][2]["role"] == "assistant"
assert trace["messages"][2]["tool_calls"][0]["function"] == {
"name": "get_last_message_from_user",
@@ -475,4 +481,4 @@ async def test_mcp_tool_list_blocking(
assert "blocked_get_last_message_from_user" in str(tools_result), (
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
+ str(tools_result)
)
)