Update McpSession class so that it can be used from both sse/streamable and stdio transports. Also update SseHeaderAttributes to McpAttributes so that it be can be used different MCP transports.

This commit is contained in:
Hemang
2025-06-02 11:59:14 +02:00
committed by Hemang Sarkar
parent 96826fa06d
commit 6849fc7daa
4 changed files with 120 additions and 52 deletions
+104 -39
View File
@@ -1,5 +1,6 @@
"""MCP Sessions Manager related classes"""
import argparse
import asyncio
import contextlib
import getpass
@@ -7,7 +8,7 @@ import os
import random
import socket
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
@@ -41,15 +42,12 @@ class McpSession(BaseModel):
"""
session_id: str
messages: List[Dict[str, Any]] = Field(default_factory=list)
metadata: Dict[str, Any] = Field(default_factory=dict)
id_to_method_mapping: Dict[int, str] = Field(default_factory=dict)
explorer_dataset: str
push_explorer: bool
invariant_api_key: Optional[str] = None
messages: list[dict[str, Any]] = Field(default_factory=list)
attributes: Optional["McpAttributes"] = None
id_to_method_mapping: dict[int, str] = Field(default_factory=dict)
trace_id: Optional[str] = None
last_trace_length: int = 0
annotations: List[Dict[str, Any]] = Field(default_factory=list)
annotations: list[dict[str, Any]] = Field(default_factory=list)
guardrails: GuardrailRuleSet = Field(
default_factory=lambda: GuardrailRuleSet(
blocking_guardrails=[], logging_guardrails=[]
@@ -57,35 +55,35 @@ class McpSession(BaseModel):
)
# When tool calls are blocked, the error message is stored here
# and sent to the client via the SSE stream.
pending_error_messages: List[dict] = Field(default_factory=list)
pending_error_messages: list[dict] = Field(default_factory=list)
# Lock to maintain in-order pushes to explorer
# and other session-related operations
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def get_invariant_api_key(self) -> str:
"""
Get the Invariant API key for the session.
Returns:
str: The Invariant API key
"""
if self.invariant_api_key:
return self.invariant_api_key
"""Get the Invariant API key for the session."""
if self.attributes.invariant_api_key:
return self.attributes.invariant_api_key
return os.getenv("INVARIANT_API_KEY")
def get_invariant_authorization(self) -> str:
"""Get the Invariant authorization header for the session."""
return "Bearer " + self.get_invariant_api_key()
async def load_guardrails(self) -> None:
"""
Load guardrails for the session.
This method fetches guardrails from the Invariant Explorer and assigns them to the session.
"""
print("Inside load_guardrails attributes: ", self.attributes, flush=True)
self.guardrails = await fetch_guardrails_from_explorer(
self.explorer_dataset,
"Bearer " + self.get_invariant_api_key(),
self.attributes.explorer_dataset,
self.get_invariant_authorization(),
# pylint: disable=no-member
self.metadata.get("mcp_client"),
self.metadata.get("mcp_server"),
self.attributes.metadata.get("mcp_client"),
self.attributes.metadata.get("mcp_server"),
)
def _deduplicate_annotations(self, new_annotations: list) -> list:
@@ -113,9 +111,11 @@ class McpSession(BaseModel):
metadata = {
"session_id": self.session_id,
"system_user": user_and_host(),
**(self.metadata or {}),
**(self.attributes.metadata or {}),
}
metadata["is_stateless_http_server"] = self.session_id.startswith(INVARIANT_SESSION_ID_PREFIX)
metadata["is_stateless_http_server"] = self.session_id.startswith(
INVARIANT_SESSION_ID_PREFIX
)
return metadata
async def get_guardrails_check_result(
@@ -134,10 +134,15 @@ class McpSession(BaseModel):
return {}
# Prepare context and select appropriate guardrails
print(
"Inside get_guardrails_check_result attributes: ",
self.attributes,
flush=True,
)
context = RequestContext.create(
request_json={},
dataset_name=self.explorer_dataset,
invariant_authorization="Bearer " + self.get_invariant_api_key(),
dataset_name=self.attributes.explorer_dataset,
invariant_authorization=self.get_invariant_authorization(),
guardrails=self.guardrails,
guardrails_parameters={
"metadata": self.session_metadata(),
@@ -159,13 +164,14 @@ class McpSession(BaseModel):
return result
async def add_message(
self, message: Dict[str, Any], guardrails_result=Dict
self, message: dict[str, Any], guardrails_result=dict
) -> None:
"""
Add a message to the session and optionally push to explorer.
Args:
message: The message to add
guardrails_result: The result of the guardrails check
"""
async with self.session_lock():
annotations = []
@@ -193,7 +199,7 @@ class McpSession(BaseModel):
# pylint: disable=no-member
self.messages.append(message)
# If push_explorer is enabled, push the trace
if self.push_explorer:
if self.attributes.push_explorer:
await self._push_trace_update(deduplicated_annotations)
async def _push_trace_update(self, deduplicated_annotations: list) -> None:
@@ -204,6 +210,7 @@ class McpSession(BaseModel):
This is an internal method that should only be called within a lock.
"""
print("Inside _push_trace_update attributes: ", self.attributes, flush=True)
try:
client = AsyncClient(
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
@@ -220,7 +227,7 @@ class McpSession(BaseModel):
response = await client.push_trace(
PushTracesRequest(
messages=[self.messages],
dataset=self.explorer_dataset,
dataset=self.attributes.explorer_dataset,
metadata=[metadata],
annotations=[deduplicated_annotations],
)
@@ -253,12 +260,12 @@ class McpSession(BaseModel):
# pylint: disable=no-member
self.pending_error_messages.append(error_message)
async def get_pending_error_messages(self) -> List[dict]:
async def get_pending_error_messages(self) -> list[dict]:
"""
Get all pending error messages for the session.
Returns:
List[dict]: A list of pending error messages
list[dict]: A list of pending error messages
"""
async with self.session_lock():
messages = list(self.pending_error_messages)
@@ -266,17 +273,22 @@ class McpSession(BaseModel):
return messages
class SseHeaderAttributes(BaseModel):
class McpAttributes(BaseModel):
"""
A Pydantic model to represent header attributes.
A Pydantic model to represent MCP attributes.
This can be initialized using HTTP headers for SSE and Streamable transports.
This can also be initialized using CLI arguments for the Stdio transport.
"""
push_explorer: bool
explorer_dataset: str
invariant_api_key: Optional[str] = None
failure_response_format: Optional[str] = None
verbose: Optional[bool] = False
metadata: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
"""
Create an instance from FastAPI request headers.
@@ -284,7 +296,7 @@ class SseHeaderAttributes(BaseModel):
headers: FastAPI Request headers
Returns:
SseHeaderAttributes: An instance with values extracted from headers
McpAttributes: An instance with values extracted from headers
"""
# Extract and process header values
project_name = headers.get("INVARIANT-PROJECT-NAME")
@@ -307,6 +319,61 @@ class SseHeaderAttributes(BaseModel):
invariant_api_key=invariant_api_key,
)
@classmethod
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
"""
Create an instance from command line arguments.
Args:
cli_args: List of command line arguments
Returns:
McpAttributes: An instance with values extracted from CLI arguments
"""
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument(
"--project-name",
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
type=str,
default=f"mcp-capture-{random.randint(1, 100)}",
)
parser.add_argument(
"--push-explorer",
help="Enable pushing traces to Invariant Explorer",
action="store_true",
)
parser.add_argument(
"--verbose",
help="Enable verbose logging",
action="store_true",
)
parser.add_argument(
"--failure-response-format",
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
type=str,
default="error",
)
config, extra_args = parser.parse_known_args(cli_args)
metadata: dict[str, Any] = {}
for arg in extra_args:
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
key, value = arg.split("=")
assert key.startswith(
"--metadata-"
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
key = key[len("--metadata-") :]
metadata[key] = value
return cls(
push_explorer=config.push_explorer,
explorer_dataset=config.project_name,
failure_response_format=config.failure_response_format,
verbose=config.verbose,
metadata=metadata,
)
class McpSessionsManager:
"""
@@ -342,7 +409,7 @@ class McpSessionsManager:
del self._session_locks[session_id]
async def initialize_session(
self, session_id: str, sse_header_attributes: SseHeaderAttributes
self, session_id: str, attributes: McpAttributes
) -> None:
"""Initialize a new session"""
# Get the lock for this specific session
@@ -354,9 +421,7 @@ class McpSessionsManager:
if session_id not in self._sessions:
session = McpSession(
session_id=session_id,
**sse_header_attributes.model_dump(
exclude_unset=True,
),
attributes=attributes,
)
self._sessions[session_id] = session
# Load guardrails for the session from the explorer
@@ -374,7 +439,7 @@ class McpSessionsManager:
return self._sessions.get(session_id)
async def add_message_to_session(
self, session_id: str, message: Dict[str, Any], guardrails_result: dict
self, session_id: str, message: dict[str, Any], guardrails_result: dict
) -> None:
"""
Add a message to a session and push to explorer if enabled.
+2 -1
View File
@@ -105,7 +105,8 @@ class McpContext:
async def load_guardrails(self):
"""Run async setup logic (e.g. fetching guardrails)."""
self.guardrails = await fetch_guardrails_from_explorer(
self.explorer_dataset, "Bearer " + os.getenv("INVARIANT_API_KEY"),
self.explorer_dataset,
"Bearer " + os.getenv("INVARIANT_API_KEY"),
self.extra_metadata.get("client", self.mcp_client_name),
self.extra_metadata.get("server", self.mcp_server_name),
)
+8 -6
View File
@@ -23,7 +23,7 @@ from gateway.common.constants import (
)
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
McpAttributes,
)
from gateway.common.mcp_utils import (
get_mcp_server_base_url,
@@ -79,7 +79,7 @@ async def mcp_post_sse_gateway(
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.metadata["mcp_client"] = (
session.attributes.metadata["mcp_client"] = (
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
@@ -137,6 +137,8 @@ async def mcp_post_sse_gateway(
print(f"[MCP POST] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
except Exception as e:
import traceback
traceback.print_exc()
print(f"[MCP POST] Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Unexpected error") from e
@@ -153,7 +155,7 @@ async def mcp_get_sse_gateway(
filtered_headers = {
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
}
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
sse_header_attributes = McpAttributes.from_request_headers(request.headers)
async def event_generator():
"""
@@ -296,7 +298,7 @@ async def mcp_get_sse_gateway(
async def _handle_endpoint_event(
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
sse: ServerSentEvent, sse_header_attributes: McpAttributes
) -> Tuple[bytes, str]:
"""
Handle the endpoint event type and modify the data accordingly.
@@ -343,7 +345,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.metadata["mcp_server"] = (
session.attributes.metadata["mcp_server"] = (
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
)
@@ -364,7 +366,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
)
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).metadata["tools"] = response_json.get(
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace
+6 -6
View File
@@ -24,7 +24,7 @@ from gateway.common.constants import (
)
from gateway.common.mcp_sessions_manager import (
McpSessionsManager,
SseHeaderAttributes,
McpAttributes,
)
from gateway.common.mcp_utils import (
get_mcp_server_base_url,
@@ -61,7 +61,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
"""
request_body_bytes = await request.body()
request_body = json.loads(request_body_bytes)
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
sse_header_attributes = McpAttributes.from_request_headers(request.headers)
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
is_initialization_request = _is_initialization_request(request_body)
@@ -308,7 +308,7 @@ def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> N
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
MCP_CLIENT_INFO
):
session.metadata["mcp_client"] = (
session.attributes.metadata["mcp_client"] = (
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
)
@@ -323,10 +323,10 @@ def _update_mcp_response_info_in_session(
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
MCP_SERVER_INFO
):
session.metadata["mcp_server"] = (
session.attributes.metadata["mcp_server"] = (
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
)
session.metadata["server_response_type"] = "json" if is_json_response else "sse"
session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse"
def _is_initialization_request(request_body: dict) -> bool:
@@ -507,7 +507,7 @@ async def _intercept_response(
# Intercept and potentially block list tool call response
elif method == MCP_LIST_TOOLS:
# store tools in metadata
session_store.get_session(session_id).metadata["tools"] = response_json.get(
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
MCP_RESULT
).get("tools")
# store tools/list tool call in trace