mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
Update metadata in MCP SSE similar to what we do in MCP stdio.
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import getpass
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -24,6 +26,14 @@ from gateway.integrations.guardrails import check_guardrails
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
|
||||
def user_and_host() -> str:
|
||||
"""Get the current user and hostname."""
|
||||
username = getpass.getuser()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
return f"{username}@{hostname}"
|
||||
|
||||
|
||||
class McpSession(BaseModel):
|
||||
"""
|
||||
Represents a single MCP session.
|
||||
@@ -81,6 +91,14 @@ class McpSession(BaseModel):
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
def session_metadata(self) -> dict:
|
||||
"""Generate metadata for the current session."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"system_user": user_and_host(),
|
||||
**(self.metadata or {}),
|
||||
}
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
self,
|
||||
message: dict,
|
||||
@@ -102,6 +120,10 @@ class McpSession(BaseModel):
|
||||
dataset_name=self.explorer_dataset,
|
||||
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
guardrails=self.guardrails,
|
||||
guardrails_parameters={
|
||||
"metadata": self.session_metadata(),
|
||||
"action": action,
|
||||
},
|
||||
)
|
||||
|
||||
guardrails_to_check = (
|
||||
@@ -170,12 +192,10 @@ class McpSession(BaseModel):
|
||||
|
||||
# If no trace exists, create a new one
|
||||
if not self.trace_id:
|
||||
# pylint: disable=no-member
|
||||
metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])}
|
||||
if self.metadata.get("mcp_client_name"):
|
||||
metadata["mcp_client"] = self.metadata.get("mcp_client_name")
|
||||
if self.metadata.get("mcp_server_name"):
|
||||
metadata["mcp_server"] = self.metadata.get("mcp_server_name")
|
||||
# default metadata
|
||||
metadata = {"source": "mcp"}
|
||||
# include MCP session metadata
|
||||
metadata.update(self.session_metadata())
|
||||
|
||||
response = await client.push_trace(
|
||||
PushTracesRequest(
|
||||
|
||||
@@ -89,7 +89,7 @@ async def mcp_post_gateway(
|
||||
if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get(
|
||||
MCP_CLIENT_INFO
|
||||
):
|
||||
session.metadata["mcp_client_name"] = (
|
||||
session.metadata["mcp_client"] = (
|
||||
request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
@@ -446,10 +446,6 @@ def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str:
|
||||
Returns:
|
||||
str: Modified server address with localhost references changed to host.docker.internal
|
||||
"""
|
||||
# check if we are running in a docker container
|
||||
if not os.environ.get("DOCKER_ENV"):
|
||||
return mcp_server_base_url
|
||||
|
||||
if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url:
|
||||
# Replace localhost or 127.0.0.1 with host.docker.internal
|
||||
modified_address = re.sub(
|
||||
@@ -510,7 +506,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
|
||||
MCP_SERVER_INFO
|
||||
):
|
||||
session.metadata["mcp_server_name"] = (
|
||||
session.metadata["mcp_server"] = (
|
||||
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
|
||||
)
|
||||
|
||||
|
||||
@@ -174,6 +174,8 @@ async def test_mcp_with_gateway_and_logging_guardrails(
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
assert "session_id" in metadata
|
||||
assert "system_user" in metadata
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
@@ -287,6 +289,8 @@ async def test_mcp_with_gateway_and_blocking_guardrails(
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
assert "session_id" in metadata
|
||||
assert "system_user" in metadata
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
@@ -388,6 +392,8 @@ async def test_mcp_with_gateway_hybrid_guardrails(
|
||||
and metadata["mcp_client"] == "mcp"
|
||||
and metadata["mcp_server"] == "messenger_server"
|
||||
)
|
||||
assert "session_id" in metadata
|
||||
assert "system_user" in metadata
|
||||
assert trace["messages"][2]["role"] == "assistant"
|
||||
assert trace["messages"][2]["tool_calls"][0]["function"] == {
|
||||
"name": "get_last_message_from_user",
|
||||
@@ -475,4 +481,4 @@ async def test_mcp_tool_list_blocking(
|
||||
assert "blocked_get_last_message_from_user" in str(tools_result), (
|
||||
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
|
||||
+ str(tools_result)
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user