Merge remote-tracking branch 'origin/main' into feature-message-timestamps

This commit is contained in:
knielsen404
2025-06-05 14:51:57 +02:00
25 changed files with 342 additions and 365 deletions
+2 -4
View File
@@ -7,8 +7,6 @@ import subprocess
import sys
import time
from typing import Optional
from gateway.mcp import stdio as mcp_stdio
from gateway.mcp.log import mcp_log
@@ -64,7 +62,7 @@ def ensure_network_exists(network_name: str = "invariant-explorer-web") -> bool:
return False
def setup_guardrails(guardrails_file_path: Optional[str] = None) -> bool:
def setup_guardrails(guardrails_file_path: str | None = None) -> bool:
"""Configure guardrails if specified."""
if not guardrails_file_path:
return True
@@ -105,7 +103,7 @@ def build():
return False
def up(guardrails_file_path: Optional[str] = None):
def up(guardrails_file_path: str | None = None):
"""Set up the local server for the Invariant Gateway."""
# Ensure network exists
if not ensure_network_exists():
+5 -6
View File
@@ -1,6 +1,5 @@
"""Common Authorization functions used in the gateway."""
from typing import Tuple, Optional
from fastapi import HTTPException, Request
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
@@ -10,7 +9,7 @@ API_KEYS_SEPARATOR = ";invariant-auth="
def extract_guardrail_service_authorization_from_headers(
request: Request,
) -> Tuple[Optional[str], Optional[str]]:
) -> tuple[str | None, str | None]:
"""
Extracts the optional Invariant-Guardrails-Authorization authorization header from the request.
@@ -22,10 +21,10 @@ def extract_guardrail_service_authorization_from_headers(
def extract_authorization_from_headers(
request: Request,
dataset_name: Optional[str] = None,
llm_provider_api_key_header: Optional[str] = None,
llm_provider_fallback_api_key_headers: Optional[list[str]] = None,
) -> Tuple[Optional[str], Optional[str]]:
dataset_name: str | None = None,
llm_provider_api_key_header: str | None = None,
llm_provider_fallback_api_key_headers: list[str] | None = None,
) -> tuple[str | None, str | None]:
"""
Extracts the Invariant authorization and LLM Provider API key from the request headers.
+4 -4
View File
@@ -3,7 +3,6 @@
import asyncio
import os
import threading
from typing import Optional
import fastapi
from httpx import HTTPStatusError
@@ -11,7 +10,7 @@ from httpx import HTTPStatusError
from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet
def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]:
def extract_policy_from_headers(request: fastapi.Request | None) -> str | None:
"""
Extracts the guardrailing policy from the request headers if present.
@@ -79,7 +78,7 @@ class GatewayConfigManager:
_lock = threading.Lock()
@classmethod
def get_config(cls, request: fastapi.Request = None) -> GatewayConfig:
def get_config(cls) -> GatewayConfig:
"""Initializes and returns the gateway configuration using double-checked locking."""
local_config = cls._config_instance
@@ -95,7 +94,7 @@ class GatewayConfigManager:
async def extract_guardrails_from_header(
request: fastapi.Request,
) -> Optional[GuardrailRuleSet]:
) -> GuardrailRuleSet | None:
"""
Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding
GuardrailRuleSet. If no guardrails are provided, returns None.
@@ -115,3 +114,4 @@ async def extract_guardrails_from_header(
blocking_guardrails=guardrails,
logging_guardrails=[],
)
return None
+3
View File
@@ -16,3 +16,6 @@ IGNORED_HEADERS = [
CLIENT_TIMEOUT = 60.0
CONTENT_TYPE_HEADER = "content-type"
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_EVENT_STREAM = "text/event-stream"
+3 -6
View File
@@ -1,10 +1,7 @@
"""Common guardrails data class."""
from enum import Enum
from typing import List
from dataclasses import dataclass
from enum import Enum
class GuardrailAction(str, Enum):
"""Enum representing the action to be taken for guardrail rules."""
@@ -27,5 +24,5 @@ class Guardrail:
class GuardrailRuleSet:
"""Grouped guardrail rules separated by their action."""
blocking_guardrails: List[Guardrail]
logging_guardrails: List[Guardrail]
blocking_guardrails: list[Guardrail]
logging_guardrails: list[Guardrail]
+20 -20
View File
@@ -1,33 +1,33 @@
"""Common Request context data class."""
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from typing import Any
import fastapi
from gateway.common.config_manager import GatewayConfig
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from gateway.common.authorization import (
extract_guardrail_service_authorization_from_headers,
)
from gateway.common.config_manager import GatewayConfig
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
@dataclass(frozen=True)
class RequestContext:
"""Structured context for a request. Must be created via `RequestContext.create()`."""
request_json: Dict[str, Any]
dataset_name: Optional[str] = None
request_json: dict[str, Any]
dataset_name: str | None = None
# authorization to use for invariant service like explorer
invariant_authorization: Optional[str] = None
invariant_authorization: str | None = None
# authorization to use for invariant guardrailing specifically
guardrail_authorization: Optional[str] = None
guardrail_authorization: str | None = None
# the set of guardrails to enforce for this request
guardrails: Optional[GuardrailRuleSet] = None
config: Dict[str, Any] = None
guardrails: GuardrailRuleSet | None = None
config: dict[str, Any] | None = None
# extra parameters available as input.<key> during guardrail evaluation
guardrails_parameters: Optional[Dict[str, Any]] = None
guardrails_parameters: dict[str, Any] | None = None
_created_via_factory: bool = field(
default=False, init=True, repr=False, compare=False
@@ -42,13 +42,13 @@ class RequestContext:
@classmethod
def create(
cls,
request_json: Dict[str, Any],
dataset_name: Optional[str] = None,
invariant_authorization: Optional[str] = None,
guardrails: Optional[GuardrailRuleSet] = None,
config: Optional[GatewayConfig] = None,
request: fastapi.Request = None,
guardrails_parameters: Optional[Dict[str, Any]] = None,
request_json: dict[str, Any],
dataset_name: str | None = None,
invariant_authorization: str | None = None,
guardrails: GuardrailRuleSet | None = None,
config: GatewayConfig | None = None,
request: fastapi.Request | None = None,
guardrails_parameters: dict[str, Any] | None = None,
) -> "RequestContext":
"""Creates a new RequestContext instance, applying default guardrails if needed."""
@@ -100,10 +100,10 @@ class RequestContext:
guardrails=guardrails,
config=context_config,
_created_via_factory=True,
guardrails_parameters=guardrails_parameters
guardrails_parameters=guardrails_parameters,
)
def get_guardrailing_authorization(self) -> Optional[str]:
def get_guardrailing_authorization(self) -> str | None:
"""
Returns the authorization to use for the guardrailing service.
+10 -11
View File
@@ -2,8 +2,9 @@
import os
import json
from typing import Any
from typing import Any, Dict, List
import httpx
from fastapi import HTTPException
from gateway.common.constants import DEFAULT_API_URL
@@ -12,12 +13,10 @@ from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
from invariant_sdk.types.annotations import AnnotationCreate
import httpx
def create_annotations_from_guardrails_errors(
guardrails_errors: List[dict],
) -> List[AnnotationCreate]:
guardrails_errors: list[dict],
) -> list[AnnotationCreate]:
"""Create Explorer annotations from the guardrails errors."""
annotations = []
@@ -68,7 +67,7 @@ def create_annotations_from_guardrails_errors(
return remove_duplicates(annotations)
def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]:
def remove_duplicates(annotations: list[AnnotationCreate]) -> list[AnnotationCreate]:
"""
Remove duplicate annotations based on content, address, and extra_metadata.
@@ -99,18 +98,18 @@ def get_explorer_api_url() -> str:
async def push_trace(
messages: List[List[Dict[str, Any]]],
messages: list[list[dict[str, Any]]],
dataset_name: str,
invariant_authorization: str,
annotations: List[List[AnnotationCreate]] = None,
metadata: List[Dict[str, Any]] = None,
annotations: list[list[AnnotationCreate]] | None = None,
metadata: list[dict[str, Any]] | None = None,
) -> PushTracesResponse:
"""Pushes traces to the dataset on the Invariant Explorer.
If a dataset with the given name does not exist, it will be created.
Args:
messages (List[List[Dict[str, Any]]]): List of messages to push.
messages (listlistdict[str, Any]]]): List of messages to push.
dataset_name (str): Name of the dataset.
invariant_authorization (str): Value of the
invariant-authorization header.
@@ -135,7 +134,7 @@ async def push_trace(
)
try:
return await client.push_trace(request)
except Exception as e:
except Exception as e: # pylint: disable=broad-except
print(f"Failed to push trace: {e}")
return {"error": str(e)}
+16 -13
View File
@@ -3,7 +3,7 @@
import asyncio
import os
import time
from typing import Any, Dict, List
from typing import Any
from functools import wraps
from fastapi import HTTPException
@@ -339,22 +339,22 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
async def check_guardrails(
messages: List[Dict[str, Any]],
guardrails: List[Guardrail],
messages: list[dict[str, Any]],
guardrails: list[Guardrail],
context: RequestContext,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Checks guardrails on the list of messages.
This calls the batch check API of the Guardrails service.
Args:
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (List[Guardrail]): The guardrails to check against.
messages (list[dict[str, Any]]): List of messages to verify the guardrails against.
guardrails (list[Guardrail]): The guardrails to check against.
invariant_authorization (str): Value of the
invariant-authorization header.
Returns:
Dict: Response containing guardrail check results.
dict: Response containing guardrail check results.
"""
async with httpx.AsyncClient() as client:
url = os.getenv("GUARDRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
@@ -364,11 +364,11 @@ async def check_guardrails(
json={
"messages": messages,
"policies": [g.content for g in guardrails],
"parameters": context.guardrails_parameters or {}
"parameters": context.guardrails_parameters or {},
},
headers={
"Authorization": context.get_guardrailing_authorization(),
"Accept": "application/json"
"Accept": "application/json",
},
timeout=5,
)
@@ -376,11 +376,14 @@ async def check_guardrails(
if result.status_code == 401:
raise HTTPException(
status_code=401,
detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format(
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER
detail=(
"The provided Invariant API key is not valid for guardrail checking. "
"Please ensure you are using the correct API key or pass an "
"alternative API key for guardrail checking specifically via the "
f"'{INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER}' header."
),
)
raise Exception(
raise Exception( # pylint: disable=broad-exception-raised
f"Guardrails check failed: {result.status_code} - {result.text}"
)
guardrails_result = result.json()
@@ -412,7 +415,7 @@ async def check_guardrails(
return aggregated_errors
except HTTPException as e:
raise e
except Exception as e:
except Exception as e: # pylint: disable=broad-except
print(f"Failed to verify guardrails: {e}")
# make sure runtime errors are also visible in e.g. Explorer
return {
-1
View File
@@ -2,7 +2,6 @@
import os
import sys
from builtins import print as builtins_print
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
+102 -107
View File
@@ -7,7 +7,7 @@ import getpass
import os
import random
import socket
from typing import Any, Optional
from typing import Any
from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.append_messages import AppendMessagesRequest
@@ -23,7 +23,6 @@ from gateway.integrations.explorer import (
fetch_guardrails_from_explorer,
)
from gateway.integrations.guardrails import check_guardrails
from gateway.mcp.constants import INVARIANT_SESSION_ID_PREFIX
def user_and_host() -> str:
@@ -33,6 +32,105 @@ def user_and_host() -> str:
return f"{username}@{hostname}"
class McpAttributes(BaseModel):
"""
A Pydantic model to represent MCP attributes.
This can be initialized using HTTP headers for SSE and Streamable transports.
This can also be initialized using CLI arguments for the Stdio transport.
"""
push_explorer: bool
explorer_dataset: str
invariant_api_key: str | None = None
verbose: bool | None = False
metadata: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
"""
Create an instance from FastAPI request headers.
Args:
headers: FastAPI Request headers
Returns:
McpAttributes: An instance with values extracted from headers
"""
# Extract and process header values
project_name = headers.get("INVARIANT-PROJECT-NAME")
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
invariant_api_key = headers.get("INVARIANT-API-KEY")
# Determine explorer_dataset
if project_name:
explorer_dataset = project_name
else:
explorer_dataset = f"mcp-capture-{random.randint(1, 100)}"
# Determine push_explorer
push_explorer = push_explorer_header == "true"
# Create and return instance
return cls(
push_explorer=push_explorer,
explorer_dataset=explorer_dataset,
invariant_api_key=invariant_api_key,
)
@classmethod
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
"""
Create an instance from command line arguments.
Args:
cli_args: List of command line arguments
Returns:
McpAttributes: An instance with values extracted from CLI arguments
"""
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument(
"--project-name",
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
type=str,
default=f"mcp-capture-{random.randint(1, 100)}",
)
parser.add_argument(
"--push-explorer",
help="Enable pushing traces to Invariant Explorer",
action="store_true",
)
parser.add_argument(
"--verbose",
help="Enable verbose logging",
action="store_true",
)
parser.add_argument(
"--failure-response-format",
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
type=str,
default="error",
)
config, extra_args = parser.parse_known_args(cli_args)
metadata: dict[str, Any] = {}
for arg in extra_args:
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
key, value = arg.split("=")
assert key.startswith(
"--metadata-"
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
key = key[len("--metadata-") :]
metadata[key] = value
return cls(
push_explorer=config.push_explorer,
explorer_dataset=config.project_name,
verbose=config.verbose,
metadata=metadata,
)
class McpSession(BaseModel):
"""
@@ -41,9 +139,9 @@ class McpSession(BaseModel):
session_id: str
messages: list[dict[str, Any]] = Field(default_factory=list)
attributes: Optional["McpAttributes"] = None
attributes: McpAttributes | None = None
id_to_method_mapping: dict[int, str] = Field(default_factory=dict)
trace_id: Optional[str] = None
trace_id: str | None = None
last_trace_length: int = 0
annotations: list[dict[str, Any]] = Field(default_factory=list)
guardrails: GuardrailRuleSet = Field(
@@ -110,9 +208,6 @@ class McpSession(BaseModel):
"system_user": user_and_host(),
**(self.attributes.metadata or {}),
}
metadata["is_stateless_http_server"] = self.session_id.startswith(
INVARIANT_SESSION_ID_PREFIX
)
return metadata
async def get_guardrails_check_result(
@@ -264,106 +359,6 @@ class McpSession(BaseModel):
return messages
class McpAttributes(BaseModel):
"""
A Pydantic model to represent MCP attributes.
This can be initialized using HTTP headers for SSE and Streamable transports.
This can also be initialized using CLI arguments for the Stdio transport.
"""
push_explorer: bool
explorer_dataset: str
invariant_api_key: Optional[str] = None
verbose: Optional[bool] = False
metadata: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
"""
Create an instance from FastAPI request headers.
Args:
headers: FastAPI Request headers
Returns:
McpAttributes: An instance with values extracted from headers
"""
# Extract and process header values
project_name = headers.get("INVARIANT-PROJECT-NAME")
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
invariant_api_key = headers.get("INVARIANT-API-KEY")
# Determine explorer_dataset
if project_name:
explorer_dataset = project_name
else:
explorer_dataset = f"mcp-capture-{random.randint(1, 100)}"
# Determine push_explorer
push_explorer = push_explorer_header == "true"
# Create and return instance
return cls(
push_explorer=push_explorer,
explorer_dataset=explorer_dataset,
invariant_api_key=invariant_api_key,
)
@classmethod
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
"""
Create an instance from command line arguments.
Args:
cli_args: List of command line arguments
Returns:
McpAttributes: An instance with values extracted from CLI arguments
"""
parser = argparse.ArgumentParser(description="MCP Gateway")
parser.add_argument(
"--project-name",
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
type=str,
default=f"mcp-capture-{random.randint(1, 100)}",
)
parser.add_argument(
"--push-explorer",
help="Enable pushing traces to Invariant Explorer",
action="store_true",
)
parser.add_argument(
"--verbose",
help="Enable verbose logging",
action="store_true",
)
parser.add_argument(
"--failure-response-format",
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
type=str,
default="error",
)
config, extra_args = parser.parse_known_args(cli_args)
metadata: dict[str, Any] = {}
for arg in extra_args:
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
key, value = arg.split("=")
assert key.startswith(
"--metadata-"
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
key = key[len("--metadata-") :]
metadata[key] = value
return cls(
push_explorer=config.push_explorer,
explorer_dataset=config.project_name,
verbose=config.verbose,
metadata=metadata,
)
class McpSessionsManager:
"""
A class to manage MCP sessions and their messages.
+50 -49
View File
@@ -8,7 +8,7 @@ import json
import re
import uuid
from abc import ABC, abstractmethod
from typing import Any, Tuple
from typing import Any
from datetime import datetime
from fastapi import Request, HTTPException
@@ -31,7 +31,7 @@ from gateway.mcp.log import format_errors_in_response
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
class MCPTransportBase(ABC):
class McpTransportBase(ABC):
"""
Abstract base class for MCP transport strategies.
@@ -44,16 +44,16 @@ class MCPTransportBase(ABC):
async def process_outgoing_request(
self, session_id: str, request_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
) -> tuple[dict[str, Any], bool]:
"""
Template method for processing outgoing requests to MCP server.
Returns:
Tuple[processed_request_data, is_blocked]
tuple[processed_request_data, is_blocked]
"""
# Update session with request information
session = self.session_store.get_session(session_id)
MCPTransportBase.update_session_from_request(session, request_data)
self.update_session_from_request(session, request_data)
# Refresh guardrails
await session.load_guardrails()
@@ -66,19 +66,19 @@ class MCPTransportBase(ABC):
async def process_incoming_response(
self, session_id: str, response_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
) -> tuple[dict[str, Any], bool]:
"""
Template method for processing incoming responses from MCP server.
Returns:
Tuple[processed_response, is_blocked]
tuple[processed_response, is_blocked]
"""
# Update session with server information
session = self.session_store.get_session(session_id)
MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data)
self.update_mcp_server_in_session_metadata(session, response_data)
# Intercept and apply guardrails to response
return await MCPTransportBase.intercept_response(
return await McpTransportBase.intercept_response(
session_id, self.session_store, response_data
)
@@ -87,20 +87,31 @@ class MCPTransportBase(ABC):
method = request_data.get(MCP_METHOD)
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
@staticmethod
def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict:
return {
"jsonrpc": "2.0",
"id": request_body.get("id"),
"error": {
"code": -32600,
"message": message,
},
}
async def _intercept_outgoing_request(
self, session_id: str, request_data: dict[str, Any]
) -> Tuple[dict[str, Any], bool]:
) -> tuple[dict[str, Any], bool]:
"""Common request interception logic for guardrails."""
method = request_data.get(MCP_METHOD)
interception_result = request_data
is_blocked = False
if method == MCP_TOOL_CALL:
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
interception_result, is_blocked = await self.hook_tool_call(
session_id, self.session_store, request_data
)
elif method == MCP_LIST_TOOLS:
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
interception_result, is_blocked = await self.hook_tool_call(
session_id=session_id,
session_store=self.session_store,
request_body={
@@ -123,7 +134,7 @@ class MCPTransportBase(ABC):
"tool_call_id": f"call_{request_body.get('id')}",
"content": request_body.get(MCP_RESULT, {}).get("content"),
"error": request_body.get(MCP_RESULT, {}).get("error"),
"timestamp": MCPTransportBase.generate_timestamp(),
"timestamp": McpTransportBase.generate_timestamp(),
}
return message
@@ -142,7 +153,7 @@ class MCPTransportBase(ABC):
"role": "assistant",
"content": "",
"tool_calls": [tool_call],
"timestamp": MCPTransportBase.generate_timestamp(),
"timestamp": McpTransportBase.generate_timestamp(),
}
return message
@@ -186,8 +197,8 @@ class MCPTransportBase(ABC):
@staticmethod
def update_session_from_request(session: McpSession, request_body: dict) -> None:
"""Update the MCP client information and request id in the session."""
MCPTransportBase.update_mcp_client_info_in_session(session, request_body)
MCPTransportBase.update_tool_call_id_in_session(session, request_body)
McpTransportBase.update_mcp_client_info_in_session(session, request_body)
McpTransportBase.update_tool_call_id_in_session(session, request_body)
@staticmethod
def get_mcp_server_base_url(request: Request) -> str:
@@ -198,7 +209,7 @@ class MCPTransportBase(ABC):
status_code=400,
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
)
return MCPTransportBase.convert_localhost_to_docker_host(
return McpTransportBase.convert_localhost_to_docker_host(
mcp_server_base_url
).rstrip("/")
@@ -233,7 +244,7 @@ class MCPTransportBase(ABC):
@staticmethod
async def hook_tool_call(
session_id: str, session_store: McpSessionsManager, request_body: dict
) -> Tuple[dict, bool]:
) -> tuple[dict, bool]:
"""
Hook to process the request JSON before sending it to the MCP server.
@@ -243,11 +254,11 @@ class MCPTransportBase(ABC):
request_body (dict): The request JSON to be processed.
Returns:
Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean
tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean
indicating whether the request was blocked. If the request is blocked, the
dict will contain an error message else it will contain the original request.
"""
message = MCPTransportBase.generate_request_message(request_body)
message = McpTransportBase.generate_request_message(request_body)
# Check for blocking guardrails
session = session_store.get_session(session_id)
@@ -259,7 +270,7 @@ class MCPTransportBase(ABC):
if (
guardrails_result
and guardrails_result.get("errors", [])
and MCPTransportBase.check_if_new_errors(
and McpTransportBase.check_if_new_errors(
session_id, session_store, guardrails_result
)
):
@@ -269,15 +280,10 @@ class MCPTransportBase(ABC):
message=message,
guardrails_result=guardrails_result,
)
return {
"jsonrpc": "2.0",
"id": request_body.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}, True
return McpTransportBase._create_jsonrpc_error_response(
request_body,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
), True
# Push trace to the explorer
await session_store.add_message_to_session(
@@ -291,7 +297,7 @@ class MCPTransportBase(ABC):
session_store: McpSessionsManager,
response_body: dict,
is_tools_list=False,
) -> Tuple[dict, bool]:
) -> tuple[dict, bool]:
"""
Hook to process the response JSON after receiving it from the MCP server.
@@ -301,7 +307,7 @@ class MCPTransportBase(ABC):
response_body (dict): The response JSON to be processed.
is_tools_list (bool): Flag to indicate if the response is from a tools/list call.
Returns:
Tuple[dict, bool]: A tuple containing the processed response JSON
tuple[dict, bool]: A tuple containing the processed response JSON
and a boolean indicating whether the response was blocked. If the response
is blocked, the dict will contain an error message else it will contain the
original response.
@@ -309,7 +315,7 @@ class MCPTransportBase(ABC):
is_blocked = False
result = response_body
message = MCPTransportBase.generate_response_message(result)
message = McpTransportBase.generate_response_message(result)
session = session_store.get_session(session_id)
guardrails_result = await session.get_guardrails_check_result(
@@ -319,22 +325,17 @@ class MCPTransportBase(ABC):
if (
guardrails_result
and guardrails_result.get("errors", [])
and MCPTransportBase.check_if_new_errors(
and McpTransportBase.check_if_new_errors(
session_id, session_store, guardrails_result
)
):
is_blocked = True
if not is_tools_list:
result = {
"jsonrpc": "2.0",
"id": response_body.get("id"),
"error": {
"code": -32600,
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
% guardrails_result["errors"],
},
}
result = McpTransportBase._create_jsonrpc_error_response(
response_body,
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
)
else:
# Special error response for tools/list
result = {
@@ -372,7 +373,7 @@ class MCPTransportBase(ABC):
@staticmethod
async def intercept_response(
session_id: str, session_store: McpSessionsManager, response_body: dict
) -> Tuple[dict, bool]:
) -> tuple[dict, bool]:
"""
Intercept the response and check for guardrails.
This function is used to intercept responses and check for guardrails.
@@ -386,7 +387,7 @@ class MCPTransportBase(ABC):
response_body (dict): The response JSON to be processed.
Returns:
Tuple[dict, bool]: A tuple containing the processed response JSON
tuple[dict, bool]: A tuple containing the processed response JSON
and a boolean indicating whether the response was blocked.
"""
session = session_store.get_session(session_id)
@@ -400,7 +401,7 @@ class MCPTransportBase(ABC):
(
intercept_response_result,
is_blocked,
) = await MCPTransportBase.hook_tool_call_response(
) = await McpTransportBase.hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body=response_body,
@@ -414,7 +415,7 @@ class MCPTransportBase(ABC):
(
intercept_response_result,
is_blocked,
) = await MCPTransportBase.hook_tool_call_response(
) = await McpTransportBase.hook_tool_call_response(
session_id=session_id,
session_store=session_store,
response_body={
@@ -431,9 +432,9 @@ class MCPTransportBase(ABC):
return intercept_response_result, is_blocked
@abstractmethod
async def initialize_session(self, *args, **kwargs) -> str:
async def initialize_session(self, **kwargs) -> str:
"""Initialize a session for this transport type."""
@abstractmethod
async def handle_communication(self, *args, **kwargs) -> Any:
async def handle_communication(self, **kwargs) -> Any:
"""Handle the main communication for this transport."""
+11 -12
View File
@@ -3,20 +3,20 @@
import asyncio
import json
import re
from typing import Any, AsyncGenerator, Optional, Tuple
from typing import Any, AsyncGenerator
import httpx
from httpx_sse import aconnect_sse, ServerSentEvent
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import CLIENT_TIMEOUT
from gateway.common.constants import CLIENT_TIMEOUT, CONTENT_TYPE_EVENT_STREAM
from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8
from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
MCP_SERVER_POST_HEADERS = {
"connection",
@@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post(
raise HTTPException(status_code=400, detail="Session does not exist")
request_body = json.loads(await request.body())
return await SSETransport(session_store).handle_post_request(
return await SseTransport(session_store).handle_post_request(
request, session_id, request_body
)
@@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream(
request: Request, session_store: McpSessionsManager
) -> StreamingResponse:
"""Integration function for SSE GET route."""
return await SSETransport(session_store).handle_sse_stream(request)
return await SseTransport(session_store).handle_sse_stream(request)
class SSETransport(MCPTransportBase):
class SseTransport(McpTransportBase):
"""
Server-Sent Events transport implementation for MCP communication.
Handles HTTP-based SSE communication with message queuing.
@@ -82,12 +82,11 @@ class SSETransport(MCPTransportBase):
async def initialize_session(
self,
*args,
**kwargs,
) -> str:
"""Initialize or get existing SSE session."""
session_id: Optional[str] = kwargs.get("session_id", None)
session_attributes: Optional[McpAttributes] = kwargs.get(
session_id: str | None = kwargs.get("session_id", None)
session_attributes: McpAttributes | None = kwargs.get(
"session_attributes", None
)
if session_id and self.session_store.session_exists(session_id):
@@ -294,17 +293,17 @@ class SSETransport(MCPTransportBase):
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
media_type=CONTENT_TYPE_EVENT_STREAM,
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
)
async def handle_communication(self, *args, **kwargs) -> StreamingResponse:
async def handle_communication(self, **kwargs) -> StreamingResponse:
"""Main communication handler for SSE transport."""
return await self.handle_sse_stream(kwargs.get("request"))
async def _handle_endpoint_event(
self, sse: ServerSentEvent, sse_header_attributes: McpAttributes
) -> Tuple[bytes, str]:
) -> tuple[bytes, str]:
"""Handle endpoint event and initialize session if needed."""
match = re.search(r"session_id=([^&\s]+)", sse.data)
session_id = match.group(1) if match else None
+6 -7
View File
@@ -7,7 +7,6 @@ import platform
import select
import subprocess
import sys
from typing import Optional, Tuple
from gateway.mcp.constants import UTF_8
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
@@ -15,7 +14,7 @@ from gateway.mcp.mcp_sessions_manager import (
McpAttributes,
McpSessionsManager,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
STATUS_EOF = "eof"
STATUS_DATA = "data"
@@ -23,7 +22,7 @@ STATUS_WAIT = "wait"
mcp_sessions_manager = McpSessionsManager()
class StdioTransport(MCPTransportBase):
class StdioTransport(McpTransportBase):
"""
STDIO transport implementation for MCP communication.
Handles subprocess-based communication with stdin/stdout/stderr.
@@ -33,7 +32,7 @@ class StdioTransport(MCPTransportBase):
super().__init__(session_store)
self.mcp_process: subprocess.Popen = None
async def initialize_session(self, *args, **kwargs) -> str:
async def initialize_session(self, **kwargs) -> str:
"""Initialize session for stdio transport."""
session_attributes: McpAttributes = kwargs.get("session_attributes")
session_id = self.generate_session_id()
@@ -53,7 +52,7 @@ class StdioTransport(MCPTransportBase):
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
return self.mcp_process
async def handle_communication(self, *args, **kwargs) -> None:
async def handle_communication(self, **kwargs) -> None:
"""Handle stdio communication loop."""
session_id: str = kwargs.get("session_id")
mcp_process: subprocess.Popen = kwargs.get("mcp_process")
@@ -210,7 +209,7 @@ class StdioTransport(MCPTransportBase):
async def _wait_for_stdin_input(
self, loop: asyncio.AbstractEventLoop, stdin_fd: int
) -> Tuple[Optional[bytes], str]:
) -> tuple[bytes | None, str]:
"""Platform-specific implementation to wait for and read input from stdin."""
if platform.system() == "Windows":
await asyncio.sleep(0.01)
@@ -261,7 +260,7 @@ async def create_stdio_transport_and_execute(
)
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
def split_args(args: list[str] | None = None) -> tuple[list[str], list[str]]:
"""
Splits CLI arguments into two parts:
1. Arguments intended for the MCP gateway (everything before `--exec`)
+23 -21
View File
@@ -1,14 +1,19 @@
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
import json
from typing import Any, Optional, Union
from typing import Any
import httpx
from httpx_sse import aconnect_sse
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from gateway.common.constants import CLIENT_TIMEOUT
from gateway.common.constants import (
CLIENT_TIMEOUT,
CONTENT_TYPE_HEADER,
CONTENT_TYPE_JSON,
CONTENT_TYPE_EVENT_STREAM,
)
from gateway.mcp.constants import (
INVARIANT_SESSION_ID_PREFIX,
MCP_CUSTOM_HEADER_PREFIX,
@@ -18,16 +23,13 @@ from gateway.mcp.mcp_sessions_manager import (
McpSessionsManager,
McpAttributes,
)
from gateway.mcp.mcp_transport_base import MCPTransportBase
from gateway.mcp.mcp_transport_base import McpTransportBase
gateway = APIRouter()
mcp_sessions_manager = McpSessionsManager()
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_SSE = "text/event-stream"
CONTENT_TYPE_HEADER = "content-type"
MCP_SESSION_ID_HEADER = "mcp-session-id"
MCP_SERVER_POST_DELETE_HEADERS = {
MCP_SERVER_POST_AND_DELETE_HEADERS = {
"connection",
"accept",
CONTENT_TYPE_HEADER,
@@ -69,7 +71,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response:
async def create_streamable_transport_and_handle_request(
request: Request, method: str, session_store: McpSessionsManager
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Integration function for streamable routes."""
streamable_transport = StreamableTransport(session_store)
return await streamable_transport.handle_communication(
@@ -77,7 +79,7 @@ async def create_streamable_transport_and_handle_request(
)
class StreamableTransport(MCPTransportBase):
class StreamableTransport(McpTransportBase):
"""
Streamable HTTP transport implementation for MCP communication.
Handles HTTP POST/GET/DELETE requests with JSON and streaming responses.
@@ -85,12 +87,11 @@ class StreamableTransport(MCPTransportBase):
async def initialize_session(
self,
*args,
**kwargs,
) -> str:
"""Initialize streamable HTTP session."""
session_id: Optional[str] = kwargs.get("session_id", None)
session_attributes: Optional[McpAttributes] = kwargs.get(
session_id: str | None = kwargs.get("session_id", None)
session_attributes: McpAttributes | None = kwargs.get(
"session_attributes", None
)
is_initialization_request: bool = kwargs.get("is_initialization_request", False)
@@ -111,7 +112,7 @@ class StreamableTransport(MCPTransportBase):
async def handle_post_request(
self, request: Request, request_body: dict[str, Any]
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Handle POST request to streamable endpoint."""
session_attributes = McpAttributes.from_request_headers(request.headers)
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
@@ -188,7 +189,7 @@ class StreamableTransport(MCPTransportBase):
return StreamingResponse(
event_generator(),
media_type=CONTENT_TYPE_SSE,
media_type=CONTENT_TYPE_EVENT_STREAM,
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
)
@@ -222,9 +223,7 @@ class StreamableTransport(MCPTransportBase):
print(f"[MCP DELETE] Request error: {str(e)}")
raise HTTPException(status_code=500, detail="Request error") from e
async def handle_communication(
self, *args, **kwargs
) -> Union[Response, StreamingResponse]:
async def handle_communication(self, **kwargs) -> Response | StreamingResponse:
"""Main communication handler for streamable transport."""
request = kwargs.get("request")
method = kwargs.get("method", "POST")
@@ -241,7 +240,7 @@ class StreamableTransport(MCPTransportBase):
async def _process_non_init_request(
self, session_id: str, request_body: dict[str, Any]
) -> Optional[Response]:
) -> Response | None:
"""Process non-initialization requests for guardrails."""
processed_request, is_blocked = await self.process_outgoing_request(
session_id, request_body
@@ -262,7 +261,7 @@ class StreamableTransport(MCPTransportBase):
session_id: str,
session_attributes: McpAttributes,
is_initialization_request: bool,
) -> Union[Response, StreamingResponse]:
) -> Response | StreamingResponse:
"""Forward request to MCP server and handle response."""
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
try:
@@ -385,7 +384,7 @@ class StreamableTransport(MCPTransportBase):
return StreamingResponse(
event_generator(),
media_type=CONTENT_TYPE_SSE,
media_type=CONTENT_TYPE_EVENT_STREAM,
headers=response_headers,
)
@@ -395,6 +394,9 @@ class StreamableTransport(MCPTransportBase):
"""Update MCP response info in session metadata."""
session = self.session_store.get_session(session_id)
self.update_mcp_server_in_session_metadata(session, response_body)
session.attributes.metadata["is_stateless_http_server"] = session_id.startswith(
INVARIANT_SESSION_ID_PREFIX
)
session.attributes.metadata["server_response_type"] = (
"json" if is_json_response else "sse"
)
@@ -405,7 +407,7 @@ class StreamableTransport(MCPTransportBase):
for k, v in request.headers.items():
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not (
if k.lower() in MCP_SERVER_POST_AND_DELETE_HEADERS and not (
k.lower() == MCP_SESSION_ID_HEADER
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
):
-42
View File
@@ -1,42 +0,0 @@
"""Task utilities for running async functions"""
import asyncio
import concurrent.futures
from contextlib import redirect_stdout
from typing import Any
from gateway.mcp.log import MCP_LOG_FILE
def run_task_sync(async_func, *args, **kwargs) -> Any:
"""
Runs an asynchronous function synchronously in a separate
thread with its own event loop. This function blocks the calling
thread until completion or timeout (10 seconds).
Args:
async_func: The async function to run
*args: Positional arguments to pass to the async function
**kwargs: Keyword arguments to pass to the async function
Returns:
Any: The return value of the async function
"""
def run_in_new_loop():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(
async_func(
*args,
**kwargs,
)
)
finally:
loop.close()
with redirect_stdout(MCP_LOG_FILE):
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
return future.result(timeout=10.0)
+33 -16
View File
@@ -14,7 +14,12 @@ from gateway.common.config_manager import (
GatewayConfigManager,
extract_guardrails_from_header,
)
from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from gateway.common.constants import (
CLIENT_TIMEOUT,
CONTENT_TYPE_JSON,
CONTENT_TYPE_EVENT_STREAM,
IGNORED_HEADERS,
)
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
from gateway.common.request_context import RequestContext
from gateway.converters.anthropic_to_invariant import (
@@ -64,7 +69,7 @@ def validate_headers(x_api_key: str = Header(None)):
)
async def anthropic_v1_messages_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
):
@@ -162,7 +167,7 @@ async def get_guardrails_check_result(
async def push_to_explorer(
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
guardrails_execution_result: dict | None = None,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
@@ -210,15 +215,18 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
self.anthropic_request: httpx.Request = anthropic_request
# response data
self.response: Optional[httpx.Response] = None
self.response_string: Optional[str] = None
self.response_json: Optional[dict[str, Any]] = None
self.response: httpx.Response | None = None
self.response_string: str | None = None
self.response_json: dict[str, Any] | None = None
# guardrailing response (if any)
self.guardrails_execution_result = {}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
"""
Check guardrails in a pipelined fashion, before processing the first
chunk (for input guardrailing).
"""
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, action=GuardrailAction.BLOCK, response_json={}
@@ -249,8 +257,8 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
Response(
content=error_chunk,
status_code=400,
media_type="application/json",
headers={"content-type": "application/json"},
media_type=CONTENT_TYPE_JSON,
headers={"content-type": CONTENT_TYPE_JSON},
)
)
@@ -263,7 +271,10 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}",
detail=(
"Invalid JSON response received from Anthropic: "
f"{self.response.text}, got error: {e}"
),
) from e
if self.response.status_code != 200:
raise HTTPException(
@@ -289,12 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
return Response(
content=content,
status_code=status_code,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
headers=dict(updated_headers),
)
async def on_end(self):
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
"""
Checks guardrails after the response is received, and asynchronously
pushes to Explorer.
"""
# ensure the response data is available
assert self.response is not None, "response is None"
assert self.response_json is not None, "response_json is None"
@@ -383,7 +397,10 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
self.sse_buffer = "" # Buffer for incomplete events
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
"""
Check guardrails in a pipelined fashion, before processing the
first chunk (for input guardrailing).
"""
if self.context.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context,
@@ -503,7 +520,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...",
flush=True,
)
except Exception as e:
except Exception as e: # pylint: disable=broad-except
print(f"Error processing event: {e}", flush=True)
# on last stream chunk, run output guardrails
@@ -536,7 +553,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
"""Process the buffer and extract complete SSE events.
Returns:
Tuple[List[str], str]: A tuple containing a list of
tuple[list[str], str]: A tuple containing a list of
complete events and the remaining buffer with incomplete events.
"""
# Split on double newlines which separate SSE events
@@ -582,7 +599,7 @@ async def handle_streaming_response(
)
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM
)
+23 -15
View File
@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any, Literal, Optional
from typing import Any, Literal
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
@@ -16,6 +16,8 @@ from gateway.common.config_manager import (
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
CONTENT_TYPE_JSON,
CONTENT_TYPE_EVENT_STREAM,
IGNORED_HEADERS,
)
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
@@ -47,7 +49,7 @@ async def gemini_generate_content_gateway(
api_version: str,
model: str,
endpoint: str,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
alt: str = Query(
None, title="Response Format", description="Set to 'sse' for streaming"
),
@@ -58,8 +60,10 @@ async def gemini_generate_content_gateway(
if endpoint not in ["generateContent", "streamGenerateContent"]:
return Response(
content="Invalid endpoint - the only endpoints supported are: \
/api/v1/gateway/gemini/<version>/models/<model-name>:generateContent or \
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:generateContent",
/api/v1/gateway/gemini/<version>/models/<model-name>:generateContent \
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:generateContent \
/api/v1/gateway/gemini/<version>/models/<model-name>:streamGenerateContent or \
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:streamGenerateContent",
status_code=400,
)
headers = {
@@ -80,7 +84,11 @@ async def gemini_generate_content_gateway(
request_json = json.loads(request_body_bytes)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
gemini_api_url = (
f"https://generativelanguage.googleapis.com/"
f"{api_version}/models/"
f"{model}:{endpoint}"
)
if alt == "sse":
gemini_api_url += "?alt=sse"
gemini_request = client.build_request(
@@ -139,7 +147,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
}
# guardrailing execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
self.guardrails_execution_result: dict[str, Any] | None = None
def make_refusal(
self,
@@ -301,7 +309,7 @@ async def stream_response(
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
media_type=CONTENT_TYPE_EVENT_STREAM,
)
@@ -407,7 +415,7 @@ async def get_guardrails_check_result(
async def push_to_explorer(
context: RequestContext,
response_json: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
guardrails_execution_result: dict | None = None,
) -> None:
"""Pushes the full trace to the Invariant Explorer"""
guardrails_execution_result = guardrails_execution_result or {}
@@ -456,11 +464,11 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
self.gemini_request: httpx.Request = gemini_request
# response data
self.response: Optional[httpx.Response] = None
self.response_json: Optional[dict[str, Any]] = None
self.response: httpx.Response | None = None
self.response_json: dict[str, Any] | None = None
# guardrails execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
self.guardrails_execution_result: dict[str, Any] | None = None
async def on_start(self):
"""
@@ -509,9 +517,9 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
Response(
content=error_chunk,
status_code=400,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
headers={
"Content-Type": "application/json",
"Content-Type": CONTENT_TYPE_JSON,
},
)
)
@@ -539,7 +547,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
headers=dict(self.response.headers),
)
@@ -582,7 +590,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
Response(
content=response_string,
status_code=response_code,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
headers=dict(self.response.headers),
)
)
+18 -15
View File
@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any, Optional
from typing import Any
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
@@ -16,6 +16,8 @@ from gateway.common.config_manager import (
)
from gateway.common.constants import (
CLIENT_TIMEOUT,
CONTENT_TYPE_JSON,
CONTENT_TYPE_EVENT_STREAM,
IGNORED_HEADERS,
)
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
@@ -60,14 +62,14 @@ def make_cors_response(request: Request, allow_methods: str) -> Response:
@gateway.options("/{dataset_name}/openai/chat/completions")
@gateway.options("/openai/chat/completions")
async def openai_chat_completions_options(request: Request, dataset_name: str = None):
async def openai_chat_completions_options(request: Request):
"""Enables CORS for the OpenAI chat completions endpoint"""
return make_cors_response(request, allow_methods="POST")
@gateway.options("/{dataset_name}/openai/models")
@gateway.options("/openai/models")
async def openai_models_options(request: Request, dataset_name: str = None):
async def openai_models_options(request: Request):
"""Enables CORS for the OpenAI models endpoint"""
return make_cors_response(request, allow_methods="GET")
@@ -76,7 +78,7 @@ async def openai_models_options(request: Request, dataset_name: str = None):
@gateway.get("/openai/models")
async def openai_models_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
):
"""Proxy request to OpenAI /models endpoint"""
headers = {
@@ -110,7 +112,7 @@ async def openai_models_gateway(
)
async def openai_chat_completions_gateway(
request: Request,
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
@@ -180,7 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
self.open_ai_request: httpx.Request = open_ai_request
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
self.guardrails_execution_result: dict | None = None
# merged_response will be updated with the data from the chunks in the stream
# At the end of the stream, this will be sent to the explorer
@@ -273,7 +275,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
}
)
# yield an extra error chunk (without preventing the original chunk to go through after)
# yield an extra error chunk (without preventing the original
# chunk to go through after)
return ExtraItem(f"data: {error_chunk}\n\n".encode())
# push will happen in on_end
@@ -324,7 +327,7 @@ async def handle_stream_response(
)
return StreamingResponse(
response.instrumented_event_generator(), media_type="text/event-stream"
response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM
)
@@ -483,7 +486,7 @@ def create_metadata(
async def push_to_explorer(
context: RequestContext,
merged_response: dict[str, Any],
guardrails_execution_result: Optional[dict] = None,
guardrails_execution_result: dict | None = None,
) -> None:
"""Pushes the merged response to the Invariant Explorer"""
# Only push the trace to explorer if the message is an end turn message
@@ -569,11 +572,11 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
self.open_ai_request: httpx.Request = open_ai_request
# request outputs
self.response: Optional[httpx.Response] = None
self.response_json: Optional[dict[str, Any]] = None
self.response: httpx.Response | None = None
self.response_json: dict[str, Any] | None = None
# guardrailing output (if any)
self.guardrails_execution_result: Optional[dict] = None
self.guardrails_execution_result: dict | None = None
async def on_start(self):
"""
@@ -606,7 +609,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
}
),
status_code=400,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
),
end_of_stream=True,
)
@@ -634,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
headers=dict(self.response.headers),
)
@@ -686,7 +689,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
Response(
content=response_string,
status_code=response_code,
media_type="application/json",
media_type=CONTENT_TYPE_JSON,
),
)
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "invariant-gateway"
version = "0.0.5.2"
version = "0.0.6"
description = "LLM proxy to observe and debug what your AI agents are doing"
readme = "README.md"
requires-python = ">=3.12"
@@ -7,7 +7,6 @@ import sys
import time
import uuid
from pathlib import Path
from typing import Dict, List
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -49,7 +48,7 @@ class WeatherAgent:
},
}
def get_response(self, messages: List[Dict]) -> List[Dict]:
def get_response(self, messages: list[dict]) -> list[dict]:
"""
Get the response from the agent for a given user query for weather.
"""
@@ -83,7 +82,7 @@ class WeatherAgent:
else:
return response_list
def get_streaming_response(self, messages: List[Dict]) -> List[Dict]:
def get_streaming_response(self, messages: list[dict]) -> list[dict]:
"""Get streaming response from the agent for a given user query for weather."""
response_list = []
+1 -3
View File
@@ -2,7 +2,6 @@
import os
import uuid
from resources.mcp.sse.client.main import run as mcp_sse_client_run
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
from resources.mcp.streamable.client.main import run as mcp_streamable_client_run
@@ -12,7 +11,6 @@ import httpx
import pytest
import requests
from datetime import datetime
from mcp.shared.exceptions import McpError
# Taken from docker-compose.test.yml
MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
@@ -706,7 +704,7 @@ async def test_mcp_message_timestamps(
"""Test that MCP messages include timestamps"""
project_name = "test-mcp-" + str(uuid.uuid4())
# Run the MCP client and make the tool call
result = await _invoke_mcp_tool(
transport,
gateway_url,
@@ -11,7 +11,7 @@ async def run(
gateway_url: str,
tool_name: str,
tool_args: dict[str, Any],
headers: dict[str, str] = None,
headers: dict[str, str] | None = None,
) -> types.CallToolResult | types.ListToolsResult:
"""
Run the MCP client with the given parameters.
@@ -3,7 +3,7 @@
import os
from datetime import timedelta
from typing import Any, Optional
from typing import Any
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
@@ -14,7 +14,7 @@ def _get_server_params(
project_name: str,
server_script_path: str,
push_to_explorer: bool,
metadata_keys: Optional[dict[str, str]] = None,
metadata_keys: dict[str, str] | None = None,
) -> StdioServerParameters:
args = [
"--from",
@@ -59,7 +59,7 @@ async def run(
push_to_explorer: bool,
tool_name: str,
tool_args: dict[str, Any],
metadata_keys: Optional[dict[str, str]] = None,
metadata_keys: dict[str, str] | None = None,
) -> types.CallToolResult | types.ListToolsResult:
"""
Main function to setup the MCP client and server.
@@ -12,7 +12,7 @@ async def run(
gateway_url: str,
tool_name: str,
tool_args: dict[str, Any],
headers: dict[str, str] = None,
headers: dict[str, str] | None = None,
) -> types.CallToolResult | types.ListToolsResult:
"""
Run the MCP client with the given parameters.
+4 -4
View File
@@ -2,7 +2,7 @@
import os
import uuid
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal
from httpx import Client
from openai import OpenAI
@@ -62,8 +62,8 @@ def get_gemini_client(
async def create_dataset(
explorer_api_url: str,
invariant_authorization: str,
dataset_name: Optional[str] = None,
) -> Dict[str, Any]:
dataset_name: str | None = None,
) -> dict[str, Any]:
"""Create a dataset in the Explorer API."""
client = Client(base_url=explorer_api_url)
response = client.post(
@@ -85,7 +85,7 @@ async def add_guardrail_to_dataset(
policy: str,
action: Literal["block", "log"],
invariant_authorization: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Add a guardrail to a dataset."""
client = Client(base_url=explorer_api_url)
response = client.post(