mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-02 17:15:41 +02:00
Update the McpSession class so that it can be used from both the SSE/Streamable and stdio transports. Also rename SseHeaderAttributes to McpAttributes so that it can be used with different MCP transports.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""MCP Sessions Manager related classes"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import getpass
|
||||
@@ -7,7 +8,7 @@ import os
|
||||
import random
|
||||
import socket
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
@@ -41,15 +42,12 @@ class McpSession(BaseModel):
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
id_to_method_mapping: Dict[int, str] = Field(default_factory=dict)
|
||||
explorer_dataset: str
|
||||
push_explorer: bool
|
||||
invariant_api_key: Optional[str] = None
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
attributes: Optional["McpAttributes"] = None
|
||||
id_to_method_mapping: dict[int, str] = Field(default_factory=dict)
|
||||
trace_id: Optional[str] = None
|
||||
last_trace_length: int = 0
|
||||
annotations: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
annotations: list[dict[str, Any]] = Field(default_factory=list)
|
||||
guardrails: GuardrailRuleSet = Field(
|
||||
default_factory=lambda: GuardrailRuleSet(
|
||||
blocking_guardrails=[], logging_guardrails=[]
|
||||
@@ -57,35 +55,35 @@ class McpSession(BaseModel):
|
||||
)
|
||||
# When tool calls are blocked, the error message is stored here
|
||||
# and sent to the client via the SSE stream.
|
||||
pending_error_messages: List[dict] = Field(default_factory=list)
|
||||
pending_error_messages: list[dict] = Field(default_factory=list)
|
||||
|
||||
# Lock to maintain in-order pushes to explorer
|
||||
# and other session-related operations
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
def get_invariant_api_key(self) -> str:
|
||||
"""
|
||||
Get the Invariant API key for the session.
|
||||
|
||||
Returns:
|
||||
str: The Invariant API key
|
||||
"""
|
||||
if self.invariant_api_key:
|
||||
return self.invariant_api_key
|
||||
"""Get the Invariant API key for the session."""
|
||||
if self.attributes.invariant_api_key:
|
||||
return self.attributes.invariant_api_key
|
||||
return os.getenv("INVARIANT_API_KEY")
|
||||
|
||||
def get_invariant_authorization(self) -> str:
    """
    Get the Invariant authorization header value for the session.

    Returns:
        str: "Bearer " followed by the value of get_invariant_api_key()

    Raises:
        ValueError: If get_invariant_api_key() yields no key.
    """
    api_key = self.get_invariant_api_key()
    if not api_key:
        # Without this guard the concatenation below fails with an opaque
        # "can only concatenate str (not NoneType)" TypeError.
        raise ValueError(
            "No Invariant API key available: provide one for the session "
            "or set the INVARIANT_API_KEY environment variable"
        )
    return "Bearer " + api_key
|
||||
|
||||
async def load_guardrails(self) -> None:
|
||||
"""
|
||||
Load guardrails for the session.
|
||||
|
||||
This method fetches guardrails from the Invariant Explorer and assigns them to the session.
|
||||
"""
|
||||
print("Inside load_guardrails attributes: ", self.attributes, flush=True)
|
||||
self.guardrails = await fetch_guardrails_from_explorer(
|
||||
self.explorer_dataset,
|
||||
"Bearer " + self.get_invariant_api_key(),
|
||||
self.attributes.explorer_dataset,
|
||||
self.get_invariant_authorization(),
|
||||
# pylint: disable=no-member
|
||||
self.metadata.get("mcp_client"),
|
||||
self.metadata.get("mcp_server"),
|
||||
self.attributes.metadata.get("mcp_client"),
|
||||
self.attributes.metadata.get("mcp_server"),
|
||||
)
|
||||
|
||||
def _deduplicate_annotations(self, new_annotations: list) -> list:
|
||||
@@ -113,9 +111,11 @@ class McpSession(BaseModel):
|
||||
metadata = {
|
||||
"session_id": self.session_id,
|
||||
"system_user": user_and_host(),
|
||||
**(self.metadata or {}),
|
||||
**(self.attributes.metadata or {}),
|
||||
}
|
||||
metadata["is_stateless_http_server"] = self.session_id.startswith(INVARIANT_SESSION_ID_PREFIX)
|
||||
metadata["is_stateless_http_server"] = self.session_id.startswith(
|
||||
INVARIANT_SESSION_ID_PREFIX
|
||||
)
|
||||
return metadata
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
@@ -134,10 +134,15 @@ class McpSession(BaseModel):
|
||||
return {}
|
||||
|
||||
# Prepare context and select appropriate guardrails
|
||||
print(
|
||||
"Inside get_guardrails_check_result attributes: ",
|
||||
self.attributes,
|
||||
flush=True,
|
||||
)
|
||||
context = RequestContext.create(
|
||||
request_json={},
|
||||
dataset_name=self.explorer_dataset,
|
||||
invariant_authorization="Bearer " + self.get_invariant_api_key(),
|
||||
dataset_name=self.attributes.explorer_dataset,
|
||||
invariant_authorization=self.get_invariant_authorization(),
|
||||
guardrails=self.guardrails,
|
||||
guardrails_parameters={
|
||||
"metadata": self.session_metadata(),
|
||||
@@ -159,13 +164,14 @@ class McpSession(BaseModel):
|
||||
return result
|
||||
|
||||
async def add_message(
|
||||
self, message: Dict[str, Any], guardrails_result=Dict
|
||||
self, message: dict[str, Any], guardrails_result=dict
|
||||
) -> None:
|
||||
"""
|
||||
Add a message to the session and optionally push to explorer.
|
||||
|
||||
Args:
|
||||
message: The message to add
|
||||
guardrails_result: The result of the guardrails check
|
||||
"""
|
||||
async with self.session_lock():
|
||||
annotations = []
|
||||
@@ -193,7 +199,7 @@ class McpSession(BaseModel):
|
||||
# pylint: disable=no-member
|
||||
self.messages.append(message)
|
||||
# If push_explorer is enabled, push the trace
|
||||
if self.push_explorer:
|
||||
if self.attributes.push_explorer:
|
||||
await self._push_trace_update(deduplicated_annotations)
|
||||
|
||||
async def _push_trace_update(self, deduplicated_annotations: list) -> None:
|
||||
@@ -204,6 +210,7 @@ class McpSession(BaseModel):
|
||||
|
||||
This is an internal method that should only be called within a lock.
|
||||
"""
|
||||
print("Inside _push_trace_update attributes: ", self.attributes, flush=True)
|
||||
try:
|
||||
client = AsyncClient(
|
||||
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
|
||||
@@ -220,7 +227,7 @@ class McpSession(BaseModel):
|
||||
response = await client.push_trace(
|
||||
PushTracesRequest(
|
||||
messages=[self.messages],
|
||||
dataset=self.explorer_dataset,
|
||||
dataset=self.attributes.explorer_dataset,
|
||||
metadata=[metadata],
|
||||
annotations=[deduplicated_annotations],
|
||||
)
|
||||
@@ -253,12 +260,12 @@ class McpSession(BaseModel):
|
||||
# pylint: disable=no-member
|
||||
self.pending_error_messages.append(error_message)
|
||||
|
||||
async def get_pending_error_messages(self) -> List[dict]:
|
||||
async def get_pending_error_messages(self) -> list[dict]:
|
||||
"""
|
||||
Get all pending error messages for the session.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of pending error messages
|
||||
list[dict]: A list of pending error messages
|
||||
"""
|
||||
async with self.session_lock():
|
||||
messages = list(self.pending_error_messages)
|
||||
@@ -266,17 +273,22 @@ class McpSession(BaseModel):
|
||||
return messages
|
||||
|
||||
|
||||
class SseHeaderAttributes(BaseModel):
|
||||
class McpAttributes(BaseModel):
|
||||
"""
|
||||
A Pydantic model to represent header attributes.
|
||||
A Pydantic model to represent MCP attributes.
|
||||
This can be initialized using HTTP headers for SSE and Streamable transports.
|
||||
This can also be initialized using CLI arguments for the Stdio transport.
|
||||
"""
|
||||
|
||||
push_explorer: bool
|
||||
explorer_dataset: str
|
||||
invariant_api_key: Optional[str] = None
|
||||
failure_response_format: Optional[str] = None
|
||||
verbose: Optional[bool] = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
|
||||
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
|
||||
"""
|
||||
Create an instance from FastAPI request headers.
|
||||
|
||||
@@ -284,7 +296,7 @@ class SseHeaderAttributes(BaseModel):
|
||||
headers: FastAPI Request headers
|
||||
|
||||
Returns:
|
||||
SseHeaderAttributes: An instance with values extracted from headers
|
||||
McpAttributes: An instance with values extracted from headers
|
||||
"""
|
||||
# Extract and process header values
|
||||
project_name = headers.get("INVARIANT-PROJECT-NAME")
|
||||
@@ -307,6 +319,61 @@ class SseHeaderAttributes(BaseModel):
|
||||
invariant_api_key=invariant_api_key,
|
||||
)
|
||||
|
||||
@classmethod
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
    """
    Create an instance from command line arguments.

    The known options (--project-name, --push-explorer, --verbose and
    --failure-response-format) are parsed with argparse. Every remaining
    argument must have the form ``--metadata-<key>=<value>`` and is
    collected into the ``metadata`` dict under ``<key>``.

    Args:
        cli_args: List of command line arguments

    Returns:
        McpAttributes: An instance with values extracted from CLI arguments

    Raises:
        ValueError: If an extra argument has no "=" or does not start
            with "--metadata-".
    """
    metadata_prefix = "--metadata-"

    parser = argparse.ArgumentParser(description="MCP Gateway")
    parser.add_argument(
        "--project-name",
        help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
        type=str,
        default=f"mcp-capture-{random.randint(1, 100)}",
    )
    parser.add_argument(
        "--push-explorer",
        help="Enable pushing traces to Invariant Explorer",
        action="store_true",
    )
    parser.add_argument(
        "--verbose",
        help="Enable verbose logging",
        action="store_true",
    )
    parser.add_argument(
        "--failure-response-format",
        help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
        type=str,
        default="error",
    )

    # Anything argparse does not recognise is treated as extra metadata.
    config, extra_args = parser.parse_known_args(cli_args)

    metadata: dict[str, Any] = {}
    for arg in extra_args:
        # Split on the FIRST "=" only, so values may themselves contain "="
        # (e.g. URLs with query strings or base64 padding).
        key, separator, value = arg.partition("=")
        # Explicit raises instead of assert: assertions are stripped when
        # Python runs with -O, which would silently disable validation.
        if not separator:
            raise ValueError(f"Invalid extra metadata argument: {arg}")
        if not key.startswith(metadata_prefix):
            raise ValueError(
                f"Invalid extra metadata argument: {arg}, must start with --metadata-"
            )
        metadata[key[len(metadata_prefix) :]] = value

    return cls(
        push_explorer=config.push_explorer,
        explorer_dataset=config.project_name,
        failure_response_format=config.failure_response_format,
        verbose=config.verbose,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
class McpSessionsManager:
|
||||
"""
|
||||
@@ -342,7 +409,7 @@ class McpSessionsManager:
|
||||
del self._session_locks[session_id]
|
||||
|
||||
async def initialize_session(
|
||||
self, session_id: str, sse_header_attributes: SseHeaderAttributes
|
||||
self, session_id: str, attributes: McpAttributes
|
||||
) -> None:
|
||||
"""Initialize a new session"""
|
||||
# Get the lock for this specific session
|
||||
@@ -354,9 +421,7 @@ class McpSessionsManager:
|
||||
if session_id not in self._sessions:
|
||||
session = McpSession(
|
||||
session_id=session_id,
|
||||
**sse_header_attributes.model_dump(
|
||||
exclude_unset=True,
|
||||
),
|
||||
attributes=attributes,
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
# Load guardrails for the session from the explorer
|
||||
@@ -374,7 +439,7 @@ class McpSessionsManager:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
async def add_message_to_session(
|
||||
self, session_id: str, message: Dict[str, Any], guardrails_result: dict
|
||||
self, session_id: str, message: dict[str, Any], guardrails_result: dict
|
||||
) -> None:
|
||||
"""
|
||||
Add a message to a session and push to explorer if enabled.
|
||||
|
||||
@@ -105,7 +105,8 @@ class McpContext:
|
||||
async def load_guardrails(self):
|
||||
"""Run async setup logic (e.g. fetching guardrails)."""
|
||||
self.guardrails = await fetch_guardrails_from_explorer(
|
||||
self.explorer_dataset, "Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
self.explorer_dataset,
|
||||
"Bearer " + os.getenv("INVARIANT_API_KEY"),
|
||||
self.extra_metadata.get("client", self.mcp_client_name),
|
||||
self.extra_metadata.get("server", self.mcp_server_name),
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ from gateway.common.constants import (
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
get_mcp_server_base_url,
|
||||
@@ -79,7 +79,7 @@ async def mcp_post_sse_gateway(
|
||||
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
|
||||
MCP_CLIENT_INFO
|
||||
):
|
||||
session.metadata["mcp_client"] = (
|
||||
session.attributes.metadata["mcp_client"] = (
|
||||
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
@@ -137,6 +137,8 @@ async def mcp_post_sse_gateway(
|
||||
print(f"[MCP POST] Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Request error") from e
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"[MCP POST] Unexpected error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Unexpected error") from e
|
||||
|
||||
@@ -153,7 +155,7 @@ async def mcp_get_sse_gateway(
|
||||
filtered_headers = {
|
||||
k: v for k, v in request.headers.items() if k.lower() in MCP_SERVER_SSE_HEADERS
|
||||
}
|
||||
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
|
||||
sse_header_attributes = McpAttributes.from_request_headers(request.headers)
|
||||
|
||||
async def event_generator():
|
||||
"""
|
||||
@@ -296,7 +298,7 @@ async def mcp_get_sse_gateway(
|
||||
|
||||
|
||||
async def _handle_endpoint_event(
|
||||
sse: ServerSentEvent, sse_header_attributes: SseHeaderAttributes
|
||||
sse: ServerSentEvent, sse_header_attributes: McpAttributes
|
||||
) -> Tuple[bytes, str]:
|
||||
"""
|
||||
Handle the endpoint event type and modify the data accordingly.
|
||||
@@ -343,7 +345,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
|
||||
MCP_SERVER_INFO
|
||||
):
|
||||
session.metadata["mcp_server"] = (
|
||||
session.attributes.metadata["mcp_server"] = (
|
||||
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
|
||||
)
|
||||
|
||||
@@ -364,7 +366,7 @@ async def _handle_message_event(session_id: str, sse: ServerSentEvent) -> bytes:
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
session_store.get_session(session_id).metadata["tools"] = response_json.get(
|
||||
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
|
||||
@@ -24,7 +24,7 @@ from gateway.common.constants import (
|
||||
)
|
||||
from gateway.common.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
SseHeaderAttributes,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.common.mcp_utils import (
|
||||
get_mcp_server_base_url,
|
||||
@@ -61,7 +61,7 @@ async def mcp_post_streamable_gateway(request: Request) -> StreamingResponse:
|
||||
"""
|
||||
request_body_bytes = await request.body()
|
||||
request_body = json.loads(request_body_bytes)
|
||||
sse_header_attributes = SseHeaderAttributes.from_request_headers(request.headers)
|
||||
sse_header_attributes = McpAttributes.from_request_headers(request.headers)
|
||||
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
is_initialization_request = _is_initialization_request(request_body)
|
||||
|
||||
@@ -308,7 +308,7 @@ def _update_mcp_client_info_in_session(session_id: str, request_body: dict) -> N
|
||||
if request_body.get(MCP_PARAMS) and request_body.get(MCP_PARAMS).get(
|
||||
MCP_CLIENT_INFO
|
||||
):
|
||||
session.metadata["mcp_client"] = (
|
||||
session.attributes.metadata["mcp_client"] = (
|
||||
request_body.get(MCP_PARAMS).get(MCP_CLIENT_INFO).get("name", "")
|
||||
)
|
||||
|
||||
@@ -323,10 +323,10 @@ def _update_mcp_response_info_in_session(
|
||||
if response_json.get(MCP_RESULT) and response_json.get(MCP_RESULT).get(
|
||||
MCP_SERVER_INFO
|
||||
):
|
||||
session.metadata["mcp_server"] = (
|
||||
session.attributes.metadata["mcp_server"] = (
|
||||
response_json.get(MCP_RESULT).get(MCP_SERVER_INFO).get("name", "")
|
||||
)
|
||||
session.metadata["server_response_type"] = "json" if is_json_response else "sse"
|
||||
session.attributes.metadata["server_response_type"] = "json" if is_json_response else "sse"
|
||||
|
||||
|
||||
def _is_initialization_request(request_body: dict) -> bool:
|
||||
@@ -507,7 +507,7 @@ async def _intercept_response(
|
||||
# Intercept and potentially block list tool call response
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
# store tools in metadata
|
||||
session_store.get_session(session_id).metadata["tools"] = response_json.get(
|
||||
session_store.get_session(session_id).attributes.metadata["tools"] = response_json.get(
|
||||
MCP_RESULT
|
||||
).get("tools")
|
||||
# store tools/list tool call in trace
|
||||
|
||||
Reference in New Issue
Block a user