diff --git a/gateway/mcp/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py index bb915d2..9add397 100644 --- a/gateway/mcp/mcp_sessions_manager.py +++ b/gateway/mcp/mcp_sessions_manager.py @@ -23,7 +23,6 @@ from gateway.integrations.explorer import ( fetch_guardrails_from_explorer, ) from gateway.integrations.guardrails import check_guardrails -from gateway.mcp.constants import INVARIANT_SESSION_ID_PREFIX def user_and_host() -> str: @@ -110,9 +109,6 @@ class McpSession(BaseModel): "system_user": user_and_host(), **(self.attributes.metadata or {}), } - metadata["is_stateless_http_server"] = self.session_id.startswith( - INVARIANT_SESSION_ID_PREFIX - ) return metadata async def get_guardrails_check_result( diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index 6367460..a0118c3 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -52,7 +52,7 @@ class McpTransportBase(ABC): """ # Update session with request information session = self.session_store.get_session(session_id) - McpTransportBase.update_session_from_request(session, request_data) + self.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -74,7 +74,7 @@ class McpTransportBase(ABC): """ # Update session with server information session = self.session_store.get_session(session_id) - McpTransportBase.update_mcp_server_in_session_metadata(session, response_data) + self.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response return await McpTransportBase.intercept_response( @@ -106,11 +106,11 @@ class McpTransportBase(ABC): interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await McpTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await McpTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index b066fb7..bf5f694 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -27,7 +27,7 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" CONTENT_TYPE_HEADER = "content-type" MCP_SESSION_ID_HEADER = "mcp-session-id" -MCP_SERVER_POST_DELETE_HEADERS = { +MCP_SERVER_POST_AND_DELETE_HEADERS = { "connection", "accept", CONTENT_TYPE_HEADER, @@ -392,6 +392,9 @@ class StreamableTransport(McpTransportBase): """Update MCP response info in session metadata.""" session = self.session_store.get_session(session_id) self.update_mcp_server_in_session_metadata(session, response_body) + session.attributes.metadata["is_stateless_http_server"] = session_id.startswith( + INVARIANT_SESSION_ID_PREFIX + ) session.attributes.metadata["server_response_type"] = ( "json" if is_json_response else "sse" ) @@ -402,7 +405,7 @@ class StreamableTransport(McpTransportBase): for k, v in request.headers.items(): if k.startswith(MCP_CUSTOM_HEADER_PREFIX): filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v - if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not ( + if k.lower() in MCP_SERVER_POST_AND_DELETE_HEADERS and not ( k.lower() == MCP_SESSION_ID_HEADER and v.startswith(INVARIANT_SESSION_ID_PREFIX) ): diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index 2606c2b..e65262b 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -2,7 +2,6 @@ import os import uuid - from resources.mcp.sse.client.main import run as mcp_sse_client_run from resources.mcp.stdio.client.main import run as mcp_stdio_client_run from resources.mcp.streamable.client.main import run as mcp_streamable_client_run @@ -12,8 +11,6 @@ import httpx import pytest import requests -from mcp.shared.exceptions import McpError - # Taken from docker-compose.test.yml MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server" MCP_SSE_SERVER_PORT = 8123