diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py index 3cf2ef0..fa53a5a 100644 --- a/gateway/common/mcp_sessions_manager.py +++ b/gateway/common/mcp_sessions_manager.py @@ -2,8 +2,10 @@ import asyncio import contextlib +import getpass import os import random +import socket from typing import Any, Dict, List, Optional @@ -24,6 +26,14 @@ from gateway.integrations.guardrails import check_guardrails DEFAULT_API_URL = "https://explorer.invariantlabs.ai" +def user_and_host() -> str: + """Get the current user and hostname.""" + username = getpass.getuser() + hostname = socket.gethostname() + + return f"{username}@{hostname}" + + class McpSession(BaseModel): """ Represents a single MCP session. @@ -81,6 +91,14 @@ class McpSession(BaseModel): async with self._lock: yield + def session_metadata(self) -> dict: + """Generate metadata for the current session.""" + return { + "session_id": self.session_id, + "system_user": user_and_host(), + **(self.metadata or {}), + } + async def get_guardrails_check_result( self, message: dict, @@ -102,6 +120,10 @@ class McpSession(BaseModel): dataset_name=self.explorer_dataset, invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), guardrails=self.guardrails, + guardrails_parameters={ + "metadata": self.session_metadata(), + "action": action, + }, ) guardrails_to_check = ( @@ -170,12 +192,10 @@ class McpSession(BaseModel): # If no trace exists, create a new one if not self.trace_id: - # pylint: disable=no-member - metadata = {"source": "mcp", "tools": self.metadata.get("tools", [])} - if self.metadata.get("mcp_client_name"): - metadata["mcp_client"] = self.metadata.get("mcp_client_name") - if self.metadata.get("mcp_server_name"): - metadata["mcp_server"] = self.metadata.get("mcp_server_name") + # default metadata + metadata = {"source": "mcp"} + # include MCP session metadata + metadata.update(self.session_metadata()) response = await client.push_trace( PushTracesRequest( diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index 12e56fd..201c37b 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -89,7 +89,7 @@ async def mcp_post_gateway( if request_json.get(MCP_PARAMS) and request_json.get(MCP_PARAMS).get( MCP_CLIENT_INFO ): - session.metadata["mcp_client_name"] = ( + session.metadata["mcp_client"] = ( request_json.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") ) @@ -446,10 +446,6 @@ def _convert_localhost_to_docker_host(mcp_server_base_url: str) -> str: Returns: str: Modified server address with localhost references changed to host.docker.internal """ - # check if we are running in a docker container - if not os.environ.get("DOCKER_ENV"): - return mcp_server_base_url - if "localhost" in mcp_server_base_url or "127.0.0.1" in mcp_server_base_url: # Replace localhost or 127.0.0.1 with host.docker.internal modified_address = re.sub( @@ -510,7 +506,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get( MCP_SERVER_INFO ): - session.metadata["mcp_server_name"] = ( + session.metadata["mcp_server"] = ( response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") ) diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index 758ac38..a4c6809 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -174,6 +174,8 @@ async def test_mcp_with_gateway_and_logging_guardrails( and metadata["mcp_client"] == "mcp" and metadata["mcp_server"] == "messenger_server" ) + assert "session_id" in metadata + assert "system_user" in metadata assert trace["messages"][2]["role"] == "assistant" assert trace["messages"][2]["tool_calls"][0]["function"] == { "name": "get_last_message_from_user", @@ -287,6 +289,8 @@ async def test_mcp_with_gateway_and_blocking_guardrails( and metadata["mcp_client"] == "mcp" and metadata["mcp_server"] == "messenger_server" ) + assert "session_id" in metadata + assert "system_user" in metadata assert trace["messages"][2]["role"] == "assistant" assert trace["messages"][2]["tool_calls"][0]["function"] == { "name": "get_last_message_from_user", @@ -388,6 +392,8 @@ async def test_mcp_with_gateway_hybrid_guardrails( and metadata["mcp_client"] == "mcp" and metadata["mcp_server"] == "messenger_server" ) + assert "session_id" in metadata + assert "system_user" in metadata assert trace["messages"][2]["role"] == "assistant" assert trace["messages"][2]["tool_calls"][0]["function"] == { "name": "get_last_message_from_user", @@ -475,4 +481,4 @@ async def test_mcp_tool_list_blocking( assert "blocked_get_last_message_from_user" in str(tools_result), ( "Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: " + str(tools_result) - ) \ No newline at end of file + )