From 6849fc7daa52ac4cda0568a82f359b44106da457 Mon Sep 17 00:00:00 2001 From: Hemang Date: Mon, 2 Jun 2025 11:59:14 +0200 Subject: [PATCH] Update McpSession class so that it can be used from both sse/streamable and stdio transports. Also update SseHeaderAttributes to McpAttributes so that it can be used with different MCP transports. --- gateway/common/mcp_sessions_manager.py | 143 ++++++++++++++++++------- gateway/mcp/mcp_context.py | 3 +- gateway/routes/mcp_sse.py | 14 +-- gateway/routes/mcp_streamable.py | 12 +-- 4 files changed, 120 insertions(+), 52 deletions(-) diff --git a/gateway/common/mcp_sessions_manager.py b/gateway/common/mcp_sessions_manager.py index 26f5b33..14a627f 100644 --- a/gateway/common/mcp_sessions_manager.py +++ b/gateway/common/mcp_sessions_manager.py @@ -1,5 +1,6 @@ """MCP Sessions Manager related classes""" +import argparse import asyncio import contextlib import getpass @@ -7,7 +8,7 @@ import os import random import socket -from typing import Any, Dict, List, Optional +from typing import Any, Optional from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.append_messages import AppendMessagesRequest @@ -41,15 +42,12 @@ class McpSession(BaseModel): """ session_id: str - messages: List[Dict[str, Any]] = Field(default_factory=list) - metadata: Dict[str, Any] = Field(default_factory=dict) - id_to_method_mapping: Dict[int, str] = Field(default_factory=dict) - explorer_dataset: str - push_explorer: bool - invariant_api_key: Optional[str] = None + messages: list[dict[str, Any]] = Field(default_factory=list) + attributes: Optional["McpAttributes"] = None + id_to_method_mapping: dict[int, str] = Field(default_factory=dict) trace_id: Optional[str] = None last_trace_length: int = 0 - annotations: List[Dict[str, Any]] = Field(default_factory=list) + annotations: list[dict[str, Any]] = Field(default_factory=list) guardrails: GuardrailRuleSet = Field( default_factory=lambda: GuardrailRuleSet( blocking_guardrails=[], 
logging_guardrails=[] @@ -57,35 +55,35 @@ class McpSession(BaseModel): ) # When tool calls are blocked, the error message is stored here # and sent to the client via the SSE stream. - pending_error_messages: List[dict] = Field(default_factory=list) + pending_error_messages: list[dict] = Field(default_factory=list) # Lock to maintain in-order pushes to explorer # and other session-related operations _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) def get_invariant_api_key(self) -> str: - """ - Get the Invariant API key for the session. - - Returns: - str: The Invariant API key - """ - if self.invariant_api_key: - return self.invariant_api_key + """Get the Invariant API key for the session.""" + if self.attributes.invariant_api_key: + return self.attributes.invariant_api_key return os.getenv("INVARIANT_API_KEY") + def get_invariant_authorization(self) -> str: + """Get the Invariant authorization header for the session.""" + return "Bearer " + self.get_invariant_api_key() + async def load_guardrails(self) -> None: """ Load guardrails for the session. This method fetches guardrails from the Invariant Explorer and assigns them to the session. 
""" + print("Inside load_guardrails attributes: ", self.attributes, flush=True) self.guardrails = await fetch_guardrails_from_explorer( - self.explorer_dataset, - "Bearer " + self.get_invariant_api_key(), + self.attributes.explorer_dataset, + self.get_invariant_authorization(), # pylint: disable=no-member - self.metadata.get("mcp_client"), - self.metadata.get("mcp_server"), + self.attributes.metadata.get("mcp_client"), + self.attributes.metadata.get("mcp_server"), ) def _deduplicate_annotations(self, new_annotations: list) -> list: @@ -113,9 +111,11 @@ class McpSession(BaseModel): metadata = { "session_id": self.session_id, "system_user": user_and_host(), - **(self.metadata or {}), + **(self.attributes.metadata or {}), } - metadata["is_stateless_http_server"] = self.session_id.startswith(INVARIANT_SESSION_ID_PREFIX) + metadata["is_stateless_http_server"] = self.session_id.startswith( + INVARIANT_SESSION_ID_PREFIX + ) return metadata async def get_guardrails_check_result( @@ -134,10 +134,15 @@ class McpSession(BaseModel): return {} # Prepare context and select appropriate guardrails + print( + "Inside get_guardrails_check_result attributes: ", + self.attributes, + flush=True, + ) context = RequestContext.create( request_json={}, - dataset_name=self.explorer_dataset, - invariant_authorization="Bearer " + self.get_invariant_api_key(), + dataset_name=self.attributes.explorer_dataset, + invariant_authorization=self.get_invariant_authorization(), guardrails=self.guardrails, guardrails_parameters={ "metadata": self.session_metadata(), @@ -159,13 +164,14 @@ class McpSession(BaseModel): return result async def add_message( - self, message: Dict[str, Any], guardrails_result=Dict + self, message: dict[str, Any], guardrails_result=dict ) -> None: """ Add a message to the session and optionally push to explorer. 
Args: message: The message to add + guardrails_result: The result of the guardrails check """ async with self.session_lock(): annotations = [] @@ -193,7 +199,7 @@ class McpSession(BaseModel): # pylint: disable=no-member self.messages.append(message) # If push_explorer is enabled, push the trace - if self.push_explorer: + if self.attributes.push_explorer: await self._push_trace_update(deduplicated_annotations) async def _push_trace_update(self, deduplicated_annotations: list) -> None: @@ -204,6 +210,7 @@ class McpSession(BaseModel): This is an internal method that should only be called within a lock. """ + print("Inside _push_trace_update attributes: ", self.attributes, flush=True) try: client = AsyncClient( api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL), @@ -220,7 +227,7 @@ class McpSession(BaseModel): response = await client.push_trace( PushTracesRequest( messages=[self.messages], - dataset=self.explorer_dataset, + dataset=self.attributes.explorer_dataset, metadata=[metadata], annotations=[deduplicated_annotations], ) @@ -253,12 +260,12 @@ class McpSession(BaseModel): # pylint: disable=no-member self.pending_error_messages.append(error_message) - async def get_pending_error_messages(self) -> List[dict]: + async def get_pending_error_messages(self) -> list[dict]: """ Get all pending error messages for the session. Returns: - List[dict]: A list of pending error messages + list[dict]: A list of pending error messages """ async with self.session_lock(): messages = list(self.pending_error_messages) @@ -266,17 +273,22 @@ class McpSession(BaseModel): return messages -class SseHeaderAttributes(BaseModel): +class McpAttributes(BaseModel): """ - A Pydantic model to represent header attributes. + A Pydantic model to represent MCP attributes. + This can be initialized using HTTP headers for SSE and Streamable transports. + This can also be initialized using CLI arguments for the Stdio transport. 
""" push_explorer: bool explorer_dataset: str invariant_api_key: Optional[str] = None + failure_response_format: Optional[str] = None + verbose: Optional[bool] = False + metadata: dict[str, Any] = Field(default_factory=dict) @classmethod - def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes": + def from_request_headers(cls, headers: Headers) -> "McpAttributes": """ Create an instance from FastAPI request headers. @@ -284,7 +296,7 @@ class SseHeaderAttributes(BaseModel): headers: FastAPI Request headers Returns: - SseHeaderAttributes: An instance with values extracted from headers + McpAttributes: An instance with values extracted from headers """ # Extract and process header values project_name = headers.get("INVARIANT-PROJECT-NAME") @@ -307,6 +319,61 @@ class SseHeaderAttributes(BaseModel): invariant_api_key=invariant_api_key, ) + @classmethod + def from_cli_args(cls, cli_args: list) -> "McpAttributes": + """ + Create an instance from command line arguments. + + Args: + cli_args: List of command line arguments + + Returns: + McpAttributes: An instance with values extracted from CLI arguments + """ + parser = argparse.ArgumentParser(description="MCP Gateway") + parser.add_argument( + "--project-name", + help="Name of the Project from Invariant Explorer where we want to push the MCP traces. 
The guardrails are pulled from this project.", + type=str, + default=f"mcp-capture-{random.randint(1, 100)}", + ) + parser.add_argument( + "--push-explorer", + help="Enable pushing traces to Invariant Explorer", + action="store_true", + ) + parser.add_argument( + "--verbose", + help="Enable verbose logging", + action="store_true", + ) + parser.add_argument( + "--failure-response-format", + help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", + type=str, + default="error", + ) + + config, extra_args = parser.parse_known_args(cli_args) + + metadata: dict[str, Any] = {} + for arg in extra_args: + assert "=" in arg, f"Invalid extra metadata argument: {arg}" + key, value = arg.split("=") + assert key.startswith( + "--metadata-" + ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" + key = key[len("--metadata-") :] + metadata[key] = value + + return cls( + push_explorer=config.push_explorer, + explorer_dataset=config.project_name, + failure_response_format=config.failure_response_format, + verbose=config.verbose, + metadata=metadata, + ) + class McpSessionsManager: """ @@ -342,7 +409,7 @@ class McpSessionsManager: del self._session_locks[session_id] async def initialize_session( - self, session_id: str, sse_header_attributes: SseHeaderAttributes + self, session_id: str, attributes: McpAttributes ) -> None: """Initialize a new session""" # Get the lock for this specific session @@ -354,9 +421,7 @@ class McpSessionsManager: if session_id not in self._sessions: session = McpSession( session_id=session_id, - **sse_header_attributes.model_dump( - exclude_unset=True, - ), + attributes=attributes, ) self._sessions[session_id] = session # Load guardrails for the session from the explorer @@ -374,7 +439,7 @@ class McpSessionsManager: return self._sessions.get(session_id) async def add_message_to_session( 
- self, session_id: str, message: Dict[str, Any], guardrails_result: dict + self, session_id: str, message: dict[str, Any], guardrails_result: dict ) -> None: """ Add a message to a session and push to explorer if enabled. diff --git a/gateway/mcp/mcp_context.py b/gateway/mcp/mcp_context.py index c2c4e6b..4da2923 100644 --- a/gateway/mcp/mcp_context.py +++ b/gateway/mcp/mcp_context.py @@ -105,7 +105,8 @@ class McpContext: async def load_guardrails(self): """Run async setup logic (e.g. fetching guardrails).""" self.guardrails = await fetch_guardrails_from_explorer( - self.explorer_dataset, "Bearer " + os.getenv("INVARIANT_API_KEY"), + self.explorer_dataset, + "Bearer " + os.getenv("INVARIANT_API_KEY"), self.extra_metadata.get("client", self.mcp_client_name), self.extra_metadata.get("server", self.mcp_server_name), ) diff --git a/gateway/routes/mcp_sse.py b/gateway/routes/mcp_sse.py index c12d38e..4f71a6f 100644 --- a/gateway/routes/mcp_sse.py +++ b/gateway/routes/mcp_sse.py @@ -23,7 +23,7 @@ from gateway.common.constants import ( ) from gateway.common.mcp_sessions_manager import ( McpSessionsManager, - SseHeaderAttributes, + McpAttributes, ) from gateway.common.mcp_utils import ( get_mcp_server_base_url, @@ -79,7 +79,7 @@ async def mcp_post_sse_gateway( if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( MCP_CLIENT_INFO ): - session.metadata["mcp_client"] = ( + session.attributes.metadata["mcp_client"] = ( request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") ) @@ -137,6 +137,8 @@ async def mcp_post_sse_gateway( print(f"[MCP POST] Request error: {str(e)}") raise HTTPException(status_code=500, detail="Request error") from e except Exception as e: + import traceback + traceback.print_exc() print(f"[MCP POST] Unexpected error: {str(e)}") raise HTTPException(status_code=500, detail="Unexpected error") from e @@ -153,7 +155,7 @@ async def mcp_get_sse_gateway( filtered_headers = { k: v for k, v in request.headers.items() if k.lower() in 
MCP_SERVER_SSE_HEADERS } - sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers) + sse_header_attributes = McpAttributes.from_request_headers(request.headers) async def event_generator(): """ @@ -296,7 +298,7 @@ async def mcp_get_sse_gateway( async def _handle_endpoint_event( - sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes + sse: ServerSentEvent, sse_header_attributes: McpAttributes ) -> Tuple[bytes, str]: """ Handle the endpoint event type and modify the data accordingly. @@ -343,7 +345,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get( MCP_SERVER_INFO ): - session.metadata["mcp_server"] = ( + session.attributes.metadata["mcp_server"] = ( response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") ) @@ -364,7 +366,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes: ) elif method == MCP_LIST_TOOLS: # store tools in metadata - session_store.get_session(session_id).metadata["tools"] = response_json.get( + session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get( MCP_RESULT ).get("tools") # store tools/list tool call in trace diff --git a/gateway/routes/mcp_streamable.py b/gateway/routes/mcp_streamable.py index 6016ff6..bc5c2e8 100644 --- a/gateway/routes/mcp_streamable.py +++ b/gateway/routes/mcp_streamable.py @@ -24,7 +24,7 @@ from gateway.common.constants import ( ) from gateway.common.mcp_sessions_manager import ( McpSessionsManager, - SseHeaderAttributes, + McpAttributes, ) from gateway.common.mcp_utils import ( get_mcp_server_base_url, @@ -61,7 +61,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse: """ request_body_bytes = await request.body() request_body = json.loads(request_body_bytes) - sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers) + sse_header_attributes = 
McpAttributes.from_request_headers(request.headers) session_id = request.headers.get(MCP_SESSION_ID_HEADER) is_initialization_request = _is_initialization_request(request_body) @@ -308,7 +308,7 @@ def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> N if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get( MCP_CLIENT_INFO ): - session.metadata["mcp_client"] = ( + session.attributes.metadata["mcp_client"] = ( request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "") ) @@ -323,10 +323,10 @@ def _update_mcp_response_info_in_session( if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get( MCP_SERVER_INFO ): - session.metadata["mcp_server"] = ( + session.attributes.metadata["mcp_server"] = ( response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "") ) - session.metadata["server_response_type"] = "json" if is_json_response else "sse" + session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse" def _is_initialization_request(request_body: dict) -> bool: @@ -507,7 +507,7 @@ async def _intercept_response( # Intercept and potentially block list tool call response elif method == MCP_LIST_TOOLS: # store tools in metadata - session_store.get_session(session_id).metadata["tools"] = response_json.get( + session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get( MCP_RESULT ).get("tools") # store tools/list tool call in trace