diff --git a/gateway/__main__.py b/gateway/__main__.py index de26a66..f994253 100644 --- a/gateway/__main__.py +++ b/gateway/__main__.py @@ -7,8 +7,6 @@ import subprocess import sys import time -from typing import Optional - from gateway.mcp import stdio as mcp_stdio from gateway.mcp.log import mcp_log @@ -64,7 +62,7 @@ def ensure_network_exists(network_name: str = "invariant-explorer-web") -> bool: return False -def setup_guardrails(guardrails_file_path: Optional[str] = None) -> bool: +def setup_guardrails(guardrails_file_path: str | None = None) -> bool: """Configure guardrails if specified.""" if not guardrails_file_path: return True @@ -105,7 +103,7 @@ def build(): return False -def up(guardrails_file_path: Optional[str] = None): +def up(guardrails_file_path: str | None = None): """Set up the local server for the Invariant Gateway.""" # Ensure network exists if not ensure_network_exists(): diff --git a/gateway/common/authorization.py b/gateway/common/authorization.py index 124b8d3..b713950 100644 --- a/gateway/common/authorization.py +++ b/gateway/common/authorization.py @@ -1,6 +1,5 @@ """Common Authorization functions used in the gateway.""" -from typing import Tuple, Optional from fastapi import HTTPException, Request INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization" @@ -10,7 +9,7 @@ API_KEYS_SEPARATOR = ";invariant-auth=" def extract_guardrail_service_authorization_from_headers( request: Request, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """ Extracts the optional Invariant-Guardrails-Authorization authorization header from the request. @@ -22,10 +21,10 @@ def extract_guardrail_service_authorization_from_headers( def extract_authorization_from_headers( request: Request, - dataset_name: Optional[str] = None, - llm_provider_api_key_header: Optional[str] = None, - llm_provider_fallback_api_key_headers: Optional[list[str]] = None, -) -> Tuple[Optional[str], Optional[str]]: + dataset_name: str | None = None, + llm_provider_api_key_header: str | None = None, + llm_provider_fallback_api_key_headers: list[str] | None = None, +) -> tuple[str | None, str | None]: """ Extracts the Invariant authorization and LLM Provider API key from the request headers. diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 98d83e5..af2f310 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -3,7 +3,6 @@ import asyncio import os import threading -from typing import Optional import fastapi from httpx import HTTPStatusError @@ -11,7 +10,7 @@ from httpx import HTTPStatusError from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet -def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]: +def extract_policy_from_headers(request: fastapi.Request | None) -> str | None: """ Extracts the guardrailing policy from the request headers if present. @@ -79,7 +78,7 @@ class GatewayConfigManager: _lock = threading.Lock() @classmethod - def get_config(cls, request: fastapi.Request = None) -> GatewayConfig: + def get_config(cls) -> GatewayConfig: """Initializes and returns the gateway configuration using double-checked locking.""" local_config = cls._config_instance @@ -95,7 +94,7 @@ class GatewayConfigManager: async def extract_guardrails_from_header( request: fastapi.Request, -) -> Optional[GuardrailRuleSet]: +) -> GuardrailRuleSet | None: """ Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding GuardrailRuleSet. If no guardrails are provided, returns None. @@ -115,3 +114,4 @@ async def extract_guardrails_from_header( blocking_guardrails=guardrails, logging_guardrails=[], ) + return None diff --git a/gateway/common/constants.py b/gateway/common/constants.py index a84eaea..20ab1d8 100644 --- a/gateway/common/constants.py +++ b/gateway/common/constants.py @@ -16,3 +16,6 @@ IGNORED_HEADERS = [ CLIENT_TIMEOUT = 60.0 +CONTENT_TYPE_HEADER = "content-type" +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_EVENT_STREAM = "text/event-stream" diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py index e4164c1..73c35f7 100644 --- a/gateway/common/guardrails.py +++ b/gateway/common/guardrails.py @@ -1,10 +1,7 @@ """Common guardrails data class.""" -from enum import Enum -from typing import List - from dataclasses import dataclass - +from enum import Enum class GuardrailAction(str, Enum): """Enum representing the action to be taken for guardrail rules.""" @@ -27,5 +24,5 @@ class Guardrail: class GuardrailRuleSet: """Grouped guardrail rules separated by their action.""" - blocking_guardrails: List[Guardrail] - logging_guardrails: List[Guardrail] + blocking_guardrails: list[Guardrail] + logging_guardrails: list[Guardrail] diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py index 11b477d..2f6afa0 100644 --- a/gateway/common/request_context.py +++ b/gateway/common/request_context.py @@ -1,33 +1,33 @@ """Common Request context data class.""" from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any import fastapi -from gateway.common.config_manager import GatewayConfig -from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from gateway.common.authorization import ( extract_guardrail_service_authorization_from_headers, ) +from gateway.common.config_manager import GatewayConfig +from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction @dataclass(frozen=True) class RequestContext: """Structured context for a request. Must be created via `RequestContext.create()`.""" - request_json: Dict[str, Any] - dataset_name: Optional[str] = None + request_json: dict[str, Any] + dataset_name: str | None = None # authorization to use for invariant service like explorer - invariant_authorization: Optional[str] = None + invariant_authorization: str | None = None # authorization to use for invariant guardrailing specifically - guardrail_authorization: Optional[str] = None + guardrail_authorization: str | None = None # the set of guardrails to enforce for this request - guardrails: Optional[GuardrailRuleSet] = None - config: Dict[str, Any] = None - + guardrails: GuardrailRuleSet | None = None + config: dict[str, Any] | None = None + # extra parameters available as input. during guardrail evaluation - guardrails_parameters: Optional[Dict[str, Any]] = None + guardrails_parameters: dict[str, Any] | None = None _created_via_factory: bool = field( default=False, init=True, repr=False, compare=False @@ -42,13 +42,13 @@ class RequestContext: @classmethod def create( cls, - request_json: Dict[str, Any], - dataset_name: Optional[str] = None, - invariant_authorization: Optional[str] = None, - guardrails: Optional[GuardrailRuleSet] = None, - config: Optional[GatewayConfig] = None, - request: fastapi.Request = None, - guardrails_parameters: Optional[Dict[str, Any]] = None, + request_json: dict[str, Any], + dataset_name: str | None = None, + invariant_authorization: str | None = None, + guardrails: GuardrailRuleSet | None = None, + config: GatewayConfig | None = None, + request: fastapi.Request | None = None, + guardrails_parameters: dict[str, Any] | None = None, ) -> "RequestContext": """Creates a new RequestContext instance, applying default guardrails if needed.""" @@ -100,10 +100,10 @@ class RequestContext: guardrails=guardrails, config=context_config, _created_via_factory=True, - guardrails_parameters=guardrails_parameters + guardrails_parameters=guardrails_parameters, ) - def get_guardrailing_authorization(self) -> Optional[str]: + def get_guardrailing_authorization(self) -> str | None: """ Returns the authorization to use for the guardrailing service. diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index e547831..007ccc0 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -2,8 +2,9 @@ import os import json +from typing import Any -from typing import Any, Dict, List +import httpx from fastapi import HTTPException from gateway.common.constants import DEFAULT_API_URL @@ -12,12 +13,10 @@ from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse from invariant_sdk.types.annotations import AnnotationCreate -import httpx - def create_annotations_from_guardrails_errors( - guardrails_errors: List[dict], -) -> List[AnnotationCreate]: + guardrails_errors: list[dict], +) -> list[AnnotationCreate]: """Create Explorer annotations from the guardrails errors.""" annotations = [] @@ -68,7 +67,7 @@ def create_annotations_from_guardrails_errors( return remove_duplicates(annotations) -def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]: +def remove_duplicates(annotations: list[AnnotationCreate]) -> list[AnnotationCreate]: """ Remove duplicate annotations based on content, address, and extra_metadata. @@ -99,18 +98,18 @@ def get_explorer_api_url() -> str: async def push_trace( - messages: List[List[Dict[str, Any]]], + messages: list[list[dict[str, Any]]], dataset_name: str, invariant_authorization: str, - annotations: List[List[AnnotationCreate]] = None, - metadata: List[Dict[str, Any]] = None, + annotations: list[list[AnnotationCreate]] | None = None, + metadata: list[dict[str, Any]] | None = None, ) -> PushTracesResponse: """Pushes traces to the dataset on the Invariant Explorer. If a dataset with the given name does not exist, it will be created. Args: - messages (List[List[Dict[str, Any]]]): List of messages to push. + messages (listlistdict[str, Any]]]): List of messages to push. dataset_name (str): Name of the dataset. invariant_authorization (str): Value of the invariant-authorization header. @@ -135,7 +134,7 @@ async def push_trace( ) try: return await client.push_trace(request) - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Failed to push trace: {e}") return {"error": str(e)} diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 3424881..7b2a28f 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -3,7 +3,7 @@ import asyncio import os import time -from typing import Any, Dict, List +from typing import Any from functools import wraps from fastapi import HTTPException @@ -339,22 +339,22 @@ class InstrumentedResponse(InstrumentedStreamingResponse): async def check_guardrails( - messages: List[Dict[str, Any]], - guardrails: List[Guardrail], + messages: list[dict[str, Any]], + guardrails: list[Guardrail], context: RequestContext, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Checks guardrails on the list of messages. This calls the batch check API of the Guardrails service. Args: - messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (List[Guardrail]): The guardrails to check against. + messages (list[dict[str, Any]]): List of messages to verify the guardrails against. + guardrails (list[Guardrail]): The guardrails to check against. invariant_authorization (str): Value of the invariant-authorization header. Returns: - Dict: Response containing guardrail check results. + dict: Response containing guardrail check results. """ async with httpx.AsyncClient() as client: url = os.getenv("GUARDRAILS_API_URL", DEFAULT_API_URL).rstrip("/") @@ -364,11 +364,11 @@ async def check_guardrails( json={ "messages": messages, "policies": [g.content for g in guardrails], - "parameters": context.guardrails_parameters or {} + "parameters": context.guardrails_parameters or {}, }, headers={ "Authorization": context.get_guardrailing_authorization(), - "Accept": "application/json" + "Accept": "application/json", }, timeout=5, ) @@ -376,11 +376,14 @@ async def check_guardrails( if result.status_code == 401: raise HTTPException( status_code=401, - detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format( - INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER + detail=( + "The provided Invariant API key is not valid for guardrail checking. " + "Please ensure you are using the correct API key or pass an " + "alternative API key for guardrail checking specifically via the " + f"'{INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER}' header." ), ) - raise Exception( + raise Exception( # pylint: disable=broad-exception-raised f"Guardrails check failed: {result.status_code} - {result.text}" ) guardrails_result = result.json() @@ -412,7 +415,7 @@ async def check_guardrails( return aggregated_errors except HTTPException as e: raise e - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Failed to verify guardrails: {e}") # make sure runtime errors are also visible in e.g. Explorer return { diff --git a/gateway/mcp/log.py b/gateway/mcp/log.py index 7287513..388355b 100644 --- a/gateway/mcp/log.py +++ b/gateway/mcp/log.py @@ -2,7 +2,6 @@ import os import sys - from builtins import print as builtins_print os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True) diff --git a/gateway/mcp/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py index bb915d2..ece6a3e 100644 --- a/gateway/mcp/mcp_sessions_manager.py +++ b/gateway/mcp/mcp_sessions_manager.py @@ -7,7 +7,7 @@ import getpass import os import random import socket -from typing import Any, Optional +from typing import Any from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.append_messages import AppendMessagesRequest @@ -23,7 +23,6 @@ from gateway.integrations.explorer import ( fetch_guardrails_from_explorer, ) from gateway.integrations.guardrails import check_guardrails -from gateway.mcp.constants import INVARIANT_SESSION_ID_PREFIX def user_and_host() -> str: @@ -33,6 +32,105 @@ def user_and_host() -> str: return f"{username}@{hostname}" +class McpAttributes(BaseModel): + """ + A Pydantic model to represent MCP attributes. + This can be initialized using HTTP headers for SSE and Streamable transports. + This can also be initialized using CLI arguments for the Stdio transport. + """ + + push_explorer: bool + explorer_dataset: str + invariant_api_key: str | None = None + verbose: bool | None = False + metadata: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_request_headers(cls, headers: Headers) -> "McpAttributes": + """ + Create an instance from FastAPI request headers. + + Args: + headers: FastAPI Request headers + + Returns: + McpAttributes: An instance with values extracted from headers + """ + # Extract and process header values + project_name = headers.get("INVARIANT-PROJECT-NAME") + push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower() + invariant_api_key = headers.get("INVARIANT-API-KEY") + + # Determine explorer_dataset + if project_name: + explorer_dataset = project_name + else: + explorer_dataset = f"mcp-capture-{random.randint(1, 100)}" + + # Determine push_explorer + push_explorer = push_explorer_header == "true" + + # Create and return instance + return cls( + push_explorer=push_explorer, + explorer_dataset=explorer_dataset, + invariant_api_key=invariant_api_key, + ) + + @classmethod + def from_cli_args(cls, cli_args: list) -> "McpAttributes": + """ + Create an instance from command line arguments. + + Args: + cli_args: List of command line arguments + + Returns: + McpAttributes: An instance with values extracted from CLI arguments + """ + parser = argparse.ArgumentParser(description="MCP Gateway") + parser.add_argument( + "--project-name", + help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.", + type=str, + default=f"mcp-capture-{random.randint(1, 100)}", + ) + parser.add_argument( + "--push-explorer", + help="Enable pushing traces to Invariant Explorer", + action="store_true", + ) + parser.add_argument( + "--verbose", + help="Enable verbose logging", + action="store_true", + ) + parser.add_argument( + "--failure-response-format", + help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", + type=str, + default="error", + ) + + config, extra_args = parser.parse_known_args(cli_args) + + metadata: dict[str, Any] = {} + for arg in extra_args: + assert "=" in arg, f"Invalid extra metadata argument: {arg}" + key, value = arg.split("=") + assert key.startswith( + "--metadata-" + ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" + key = key[len("--metadata-") :] + metadata[key] = value + + return cls( + push_explorer=config.push_explorer, + explorer_dataset=config.project_name, + verbose=config.verbose, + metadata=metadata, + ) + class McpSession(BaseModel): """ @@ -41,9 +139,9 @@ class McpSession(BaseModel): session_id: str messages: list[dict[str, Any]] = Field(default_factory=list) - attributes: Optional["McpAttributes"] = None + attributes: McpAttributes | None = None id_to_method_mapping: dict[int, str] = Field(default_factory=dict) - trace_id: Optional[str] = None + trace_id: str | None = None last_trace_length: int = 0 annotations: list[dict[str, Any]] = Field(default_factory=list) guardrails: GuardrailRuleSet = Field( @@ -110,9 +208,6 @@ class McpSession(BaseModel): "system_user": user_and_host(), **(self.attributes.metadata or {}), } - metadata["is_stateless_http_server"] = self.session_id.startswith( - INVARIANT_SESSION_ID_PREFIX - ) return metadata async def get_guardrails_check_result( @@ -264,106 +359,6 @@ class McpSession(BaseModel): return messages -class McpAttributes(BaseModel): - """ - A Pydantic model to represent MCP attributes. - This can be initialized using HTTP headers for SSE and Streamable transports. - This can also be initialized using CLI arguments for the Stdio transport. - """ - - push_explorer: bool - explorer_dataset: str - invariant_api_key: Optional[str] = None - verbose: Optional[bool] = False - metadata: dict[str, Any] = Field(default_factory=dict) - - @classmethod - def from_request_headers(cls, headers: Headers) -> "McpAttributes": - """ - Create an instance from FastAPI request headers. - - Args: - headers: FastAPI Request headers - - Returns: - McpAttributes: An instance with values extracted from headers - """ - # Extract and process header values - project_name = headers.get("INVARIANT-PROJECT-NAME") - push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower() - invariant_api_key = headers.get("INVARIANT-API-KEY") - - # Determine explorer_dataset - if project_name: - explorer_dataset = project_name - else: - explorer_dataset = f"mcp-capture-{random.randint(1, 100)}" - - # Determine push_explorer - push_explorer = push_explorer_header == "true" - - # Create and return instance - return cls( - push_explorer=push_explorer, - explorer_dataset=explorer_dataset, - invariant_api_key=invariant_api_key, - ) - - @classmethod - def from_cli_args(cls, cli_args: list) -> "McpAttributes": - """ - Create an instance from command line arguments. - - Args: - cli_args: List of command line arguments - - Returns: - McpAttributes: An instance with values extracted from CLI arguments - """ - parser = argparse.ArgumentParser(description="MCP Gateway") - parser.add_argument( - "--project-name", - help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.", - type=str, - default=f"mcp-capture-{random.randint(1, 100)}", - ) - parser.add_argument( - "--push-explorer", - help="Enable pushing traces to Invariant Explorer", - action="store_true", - ) - parser.add_argument( - "--verbose", - help="Enable verbose logging", - action="store_true", - ) - parser.add_argument( - "--failure-response-format", - help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", - type=str, - default="error", - ) - - config, extra_args = parser.parse_known_args(cli_args) - - metadata: dict[str, Any] = {} - for arg in extra_args: - assert "=" in arg, f"Invalid extra metadata argument: {arg}" - key, value = arg.split("=") - assert key.startswith( - "--metadata-" - ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" - key = key[len("--metadata-") :] - metadata[key] = value - - return cls( - push_explorer=config.push_explorer, - explorer_dataset=config.project_name, - verbose=config.verbose, - metadata=metadata, - ) - - class McpSessionsManager: """ A class to manage MCP sessions and their messages. diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index d9c726f..241ad63 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -8,7 +8,7 @@ import json import re import uuid from abc import ABC, abstractmethod -from typing import Any, Tuple +from typing import Any from datetime import datetime from fastapi import Request, HTTPException @@ -31,7 +31,7 @@ from gateway.mcp.log import format_errors_in_response from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager -class MCPTransportBase(ABC): +class McpTransportBase(ABC): """ Abstract base class for MCP transport strategies. @@ -44,16 +44,16 @@ class MCPTransportBase(ABC): async def process_outgoing_request( self, session_id: str, request_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """ Template method for processing outgoing requests to MCP server. Returns: - Tuple[processed_request_data, is_blocked] + tuple[processed_request_data, is_blocked] """ # Update session with request information session = self.session_store.get_session(session_id) - MCPTransportBase.update_session_from_request(session, request_data) + self.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -66,19 +66,19 @@ class MCPTransportBase(ABC): async def process_incoming_response( self, session_id: str, response_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """ Template method for processing incoming responses from MCP server. Returns: - Tuple[processed_response, is_blocked] + tuple[processed_response, is_blocked] """ # Update session with server information session = self.session_store.get_session(session_id) - MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data) + self.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response - return await MCPTransportBase.intercept_response( + return await McpTransportBase.intercept_response( session_id, self.session_store, response_data ) @@ -87,20 +87,31 @@ class MCPTransportBase(ABC): method = request_data.get(MCP_METHOD) return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS] + @staticmethod + def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict: + return { + "jsonrpc": "2.0", + "id": request_body.get("id"), + "error": { + "code": -32600, + "message": message, + }, + } + async def _intercept_outgoing_request( self, session_id: str, request_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """Common request interception logic for guardrails.""" method = request_data.get(MCP_METHOD) interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ @@ -123,7 +134,7 @@ class MCPTransportBase(ABC): "tool_call_id": f"call_{request_body.get('id')}", "content": request_body.get(MCP_RESULT, {}).get("content"), "error": request_body.get(MCP_RESULT, {}).get("error"), - "timestamp": MCPTransportBase.generate_timestamp(), + "timestamp": McpTransportBase.generate_timestamp(), } return message @@ -142,7 +153,7 @@ class MCPTransportBase(ABC): "role": "assistant", "content": "", "tool_calls": [tool_call], - "timestamp": MCPTransportBase.generate_timestamp(), + "timestamp": McpTransportBase.generate_timestamp(), } return message @@ -186,8 +197,8 @@ class MCPTransportBase(ABC): @staticmethod def update_session_from_request(session: McpSession, request_body: dict) -> None: """Update the MCP client information and request id in the session.""" - MCPTransportBase.update_mcp_client_info_in_session(session, request_body) - MCPTransportBase.update_tool_call_id_in_session(session, request_body) + McpTransportBase.update_mcp_client_info_in_session(session, request_body) + McpTransportBase.update_tool_call_id_in_session(session, request_body) @staticmethod def get_mcp_server_base_url(request: Request) -> str: @@ -198,7 +209,7 @@ class MCPTransportBase(ABC): status_code=400, detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) - return MCPTransportBase.convert_localhost_to_docker_host( + return McpTransportBase.convert_localhost_to_docker_host( mcp_server_base_url ).rstrip("/") @@ -233,7 +244,7 @@ class MCPTransportBase(ABC): @staticmethod async def hook_tool_call( session_id: str, session_store: McpSessionsManager, request_body: dict - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Hook to process the request JSON before sending it to the MCP server. @@ -243,11 +254,11 @@ class MCPTransportBase(ABC): request_body (dict): The request JSON to be processed. Returns: - Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean + tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean indicating whether the request was blocked. If the request is blocked, the dict will contain an error message else it will contain the original request. """ - message = MCPTransportBase.generate_request_message(request_body) + message = McpTransportBase.generate_request_message(request_body) # Check for blocking guardrails session = session_store.get_session(session_id) @@ -259,7 +270,7 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): @@ -269,15 +280,10 @@ class MCPTransportBase(ABC): message=message, guardrails_result=guardrails_result, ) - return { - "jsonrpc": "2.0", - "id": request_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - }, True + return McpTransportBase._create_jsonrpc_error_response( + request_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ), True # Push trace to the explorer await session_store.add_message_to_session( @@ -291,7 +297,7 @@ class MCPTransportBase(ABC): session_store: McpSessionsManager, response_body: dict, is_tools_list=False, - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Hook to process the response JSON after receiving it from the MCP server. @@ -301,7 +307,7 @@ class MCPTransportBase(ABC): response_body (dict): The response JSON to be processed. is_tools_list (bool): Flag to indicate if the response is from a tools/list call. Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON + tuple[dict, bool]: A tuple containing the processed response JSON and a boolean indicating whether the response was blocked. If the response is blocked, the dict will contain an error message else it will contain the original response. @@ -309,7 +315,7 @@ class MCPTransportBase(ABC): is_blocked = False result = response_body - message = MCPTransportBase.generate_response_message(result) + message = McpTransportBase.generate_response_message(result) session = session_store.get_session(session_id) guardrails_result = await session.get_guardrails_check_result( @@ -319,22 +325,17 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): is_blocked = True if not is_tools_list: - result = { - "jsonrpc": "2.0", - "id": response_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - } + result = McpTransportBase._create_jsonrpc_error_response( + response_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ) else: # Special error response for tools/list result = { @@ -372,7 +373,7 @@ class MCPTransportBase(ABC): @staticmethod async def intercept_response( session_id: str, session_store: McpSessionsManager, response_body: dict - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Intercept the response and check for guardrails. This function is used to intercept responses and check for guardrails. @@ -386,7 +387,7 @@ class MCPTransportBase(ABC): response_body (dict): The response JSON to be processed. Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON + tuple[dict, bool]: A tuple containing the processed response JSON and a boolean indicating whether the response was blocked. """ session = session_store.get_session(session_id) @@ -400,7 +401,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body=response_body, @@ -414,7 +415,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body={ @@ -431,9 +432,9 @@ class MCPTransportBase(ABC): return intercept_response_result, is_blocked @abstractmethod - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize a session for this transport type.""" @abstractmethod - async def handle_communication(self, *args, **kwargs) -> Any: + async def handle_communication(self, **kwargs) -> Any: """Handle the main communication for this transport.""" diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 3109175..3f3d0c3 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -3,20 +3,20 @@ import asyncio import json import re -from typing import Any, AsyncGenerator, Optional, Tuple +from typing import Any, AsyncGenerator import httpx from httpx_sse import aconnect_sse, ServerSentEvent from fastapi import APIRouter, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from gateway.common.constants import CLIENT_TIMEOUT +from gateway.common.constants import CLIENT_TIMEOUT, CONTENT_TYPE_EVENT_STREAM from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8 from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase MCP_SERVER_POST_HEADERS = { "connection", @@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post( raise HTTPException(status_code=400, detail="Session does not exist") request_body = json.loads(await request.body()) - return await SSETransport(session_store).handle_post_request( + return await SseTransport(session_store).handle_post_request( request, session_id, request_body ) @@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream( request: Request, session_store: McpSessionsManager ) -> StreamingResponse: """Integration function for SSE GET route.""" - return await SSETransport(session_store).handle_sse_stream(request) + return await SseTransport(session_store).handle_sse_stream(request) -class SSETransport(MCPTransportBase): +class SseTransport(McpTransportBase): """ Server-Sent Events transport implementation for MCP communication. Handles HTTP-based SSE communication with message queuing. @@ -82,12 +82,11 @@ class SSETransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize or get existing SSE session.""" - session_id: Optional[str] = kwargs.get("session_id", None) - session_attributes: Optional[McpAttributes] = kwargs.get( + session_id: str | None = kwargs.get("session_id", None) + session_attributes: McpAttributes | None = kwargs.get( "session_attributes", None ) if session_id and self.session_store.session_exists(session_id): @@ -294,17 +293,17 @@ class SSETransport(MCPTransportBase): return StreamingResponse( event_generator(), - media_type="text/event-stream", + media_type=CONTENT_TYPE_EVENT_STREAM, headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) - async def handle_communication(self, *args, **kwargs) -> StreamingResponse: + async def handle_communication(self, **kwargs) -> StreamingResponse: """Main communication handler for SSE transport.""" return await self.handle_sse_stream(kwargs.get("request")) async def _handle_endpoint_event( self, sse: ServerSentEvent, sse_header_attributes: McpAttributes - ) -> Tuple[bytes, str]: + ) -> tuple[bytes, str]: """Handle endpoint event and initialize session if needed.""" match = re.search(r"session_id=([^&\s]+)", sse.data) session_id = match.group(1) if match else None diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py index 4a5fd73..c58b61c 100644 --- a/gateway/mcp/stdio.py +++ b/gateway/mcp/stdio.py @@ -7,7 +7,6 @@ import platform import select import subprocess import sys -from typing import Optional, Tuple from gateway.mcp.constants import UTF_8 from gateway.mcp.log import mcp_log, MCP_LOG_FILE @@ -15,7 +14,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpAttributes, McpSessionsManager, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase STATUS_EOF = "eof" STATUS_DATA = "data" @@ -23,7 +22,7 @@ STATUS_WAIT = "wait" mcp_sessions_manager = McpSessionsManager() -class StdioTransport(MCPTransportBase): +class StdioTransport(McpTransportBase): """ STDIO transport implementation for MCP communication. Handles subprocess-based communication with stdin/stdout/stderr. @@ -33,7 +32,7 @@ class StdioTransport(MCPTransportBase): super().__init__(session_store) self.mcp_process: subprocess.Popen = None - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize session for stdio transport.""" session_attributes: McpAttributes = kwargs.get("session_attributes") session_id = self.generate_session_id() @@ -53,7 +52,7 @@ class StdioTransport(MCPTransportBase): mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}") return self.mcp_process - async def handle_communication(self, *args, **kwargs) -> None: + async def handle_communication(self, **kwargs) -> None: """Handle stdio communication loop.""" session_id: str = kwargs.get("session_id") mcp_process: subprocess.Popen = kwargs.get("mcp_process") @@ -210,7 +209,7 @@ class StdioTransport(MCPTransportBase): async def _wait_for_stdin_input( self, loop: asyncio.AbstractEventLoop, stdin_fd: int - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: """Platform-specific implementation to wait for and read input from stdin.""" if platform.system() == "Windows": await asyncio.sleep(0.01) @@ -261,7 +260,7 @@ async def create_stdio_transport_and_execute( ) -def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: +def split_args(args: list[str] | None = None) -> tuple[list[str], list[str]]: """ Splits CLI arguments into two parts: 1. Arguments intended for the MCP gateway (everything before `--exec`) diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index 39a2b01..15e67e8 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -1,14 +1,19 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" import json -from typing import Any, Optional, Union +from typing import Any import httpx from httpx_sse import aconnect_sse from fastapi import APIRouter, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from gateway.common.constants import CLIENT_TIMEOUT +from gateway.common.constants import ( + CLIENT_TIMEOUT, + CONTENT_TYPE_HEADER, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, +) from gateway.mcp.constants import ( INVARIANT_SESSION_ID_PREFIX, MCP_CUSTOM_HEADER_PREFIX, @@ -18,16 +23,13 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase gateway = APIRouter() mcp_sessions_manager = McpSessionsManager() -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_SSE = "text/event-stream" -CONTENT_TYPE_HEADER = "content-type" MCP_SESSION_ID_HEADER = "mcp-session-id" -MCP_SERVER_POST_DELETE_HEADERS = { +MCP_SERVER_POST_AND_DELETE_HEADERS = { "connection", "accept", CONTENT_TYPE_HEADER, @@ -69,7 +71,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response: async def create_streamable_transport_and_handle_request( request: Request, method: str, session_store: McpSessionsManager -) -> Union[Response, StreamingResponse]: +) -> Response | StreamingResponse: """Integration function for streamable routes.""" streamable_transport = StreamableTransport(session_store) return await streamable_transport.handle_communication( @@ -77,7 +79,7 @@ async def create_streamable_transport_and_handle_request( ) -class StreamableTransport(MCPTransportBase): +class StreamableTransport(McpTransportBase): """ Streamable HTTP transport implementation for MCP communication. Handles HTTP POST/GET/DELETE requests with JSON and streaming responses. @@ -85,12 +87,11 @@ class StreamableTransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize streamable HTTP session.""" - session_id: Optional[str] = kwargs.get("session_id", None) - session_attributes: Optional[McpAttributes] = kwargs.get( + session_id: str | None = kwargs.get("session_id", None) + session_attributes: McpAttributes | None = kwargs.get( "session_attributes", None ) is_initialization_request: bool = kwargs.get("is_initialization_request", False) @@ -111,7 +112,7 @@ class StreamableTransport(MCPTransportBase): async def handle_post_request( self, request: Request, request_body: dict[str, Any] - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Handle POST request to streamable endpoint.""" session_attributes = McpAttributes.from_request_headers(request.headers) session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -188,7 +189,7 @@ class StreamableTransport(MCPTransportBase): return StreamingResponse( event_generator(), - media_type=CONTENT_TYPE_SSE, + media_type=CONTENT_TYPE_EVENT_STREAM, headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) @@ -222,9 +223,7 @@ class StreamableTransport(MCPTransportBase): print(f"[MCP DELETE] Request error: {str(e)}") raise HTTPException(status_code=500, detail="Request error") from e - async def handle_communication( - self, *args, **kwargs - ) -> Union[Response, StreamingResponse]: + async def handle_communication(self, **kwargs) -> Response | StreamingResponse: """Main communication handler for streamable transport.""" request = kwargs.get("request") method = kwargs.get("method", "POST") @@ -241,7 +240,7 @@ class StreamableTransport(MCPTransportBase): async def _process_non_init_request( self, session_id: str, request_body: dict[str, Any] - ) -> Optional[Response]: + ) -> Response | None: """Process non-initialization requests for guardrails.""" processed_request, is_blocked = await self.process_outgoing_request( session_id, request_body @@ -262,7 +261,7 @@ class StreamableTransport(MCPTransportBase): session_id: str, session_attributes: McpAttributes, is_initialization_request: bool, - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Forward request to MCP server and handle response.""" async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: try: @@ -385,7 +384,7 @@ class StreamableTransport(MCPTransportBase): return StreamingResponse( event_generator(), - media_type=CONTENT_TYPE_SSE, + media_type=CONTENT_TYPE_EVENT_STREAM, headers=response_headers, ) @@ -395,6 +394,9 @@ class StreamableTransport(MCPTransportBase): """Update MCP response info in session metadata.""" session = self.session_store.get_session(session_id) self.update_mcp_server_in_session_metadata(session, response_body) + session.attributes.metadata["is_stateless_http_server"] = session_id.startswith( + INVARIANT_SESSION_ID_PREFIX + ) session.attributes.metadata["server_response_type"] = ( "json" if is_json_response else "sse" ) @@ -405,7 +407,7 @@ class StreamableTransport(MCPTransportBase): for k, v in request.headers.items(): if k.startswith(MCP_CUSTOM_HEADER_PREFIX): filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v - if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not ( + if k.lower() in MCP_SERVER_POST_AND_DELETE_HEADERS and not ( k.lower() == MCP_SESSION_ID_HEADER and v.startswith(INVARIANT_SESSION_ID_PREFIX) ): diff --git a/gateway/mcp/task_utils.py b/gateway/mcp/task_utils.py deleted file mode 100644 index 2ff8b87..0000000 --- a/gateway/mcp/task_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Task utilities for running async functions""" - -import asyncio -import concurrent.futures - -from contextlib import redirect_stdout -from typing import Any - -from gateway.mcp.log import MCP_LOG_FILE - - -def run_task_sync(async_func, *args, **kwargs) -> Any: - """ - Runs an asynchronous function synchronously in a separate - thread with its own event loop. This function blocks the calling - thread until completion or timeout (10 seconds). - - Args: - async_func: The async function to run - *args: Positional arguments to pass to the async function - **kwargs: Keyword arguments to pass to the async function - - Returns: - Any: The return value of the async function - """ - - def run_in_new_loop(): - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete( - async_func( - *args, - **kwargs, - ) - ) - finally: - loop.close() - - with redirect_stdout(MCP_LOG_FILE): - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_new_loop) - return future.result(timeout=10.0) diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 36c4a63..0fcbb1f 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -14,7 +14,12 @@ from gateway.common.config_manager import ( GatewayConfigManager, extract_guardrails_from_header, ) -from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS +from gateway.common.constants import ( + CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, + IGNORED_HEADERS, +) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.request_context import RequestContext from gateway.converters.anthropic_to_invariant import ( @@ -64,7 +69,7 @@ def validate_headers(x_api_key: str = Header(None)): ) async def anthropic_v1_messages_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ): @@ -162,7 +167,7 @@ async def get_guardrails_check_result( async def push_to_explorer( context: RequestContext, merged_response: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} @@ -210,15 +215,18 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): self.anthropic_request: httpx.Request = anthropic_request # response data - self.response: Optional[httpx.Response] = None - self.response_string: Optional[str] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_string: str | None = None + self.response_json: dict[str, Any] | None = None # guardrailing response (if any) self.guardrails_execution_result = {} async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + """ + Check guardrails in a pipelined fashion, before processing the first + chunk (for input guardrailing). + """ if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, response_json={} @@ -249,8 +257,8 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): Response( content=error_chunk, status_code=400, - media_type="application/json", - headers={"content-type": "application/json"}, + media_type=CONTENT_TYPE_JSON, + headers={"content-type": CONTENT_TYPE_JSON}, ) ) @@ -263,7 +271,10 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, - detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}", + detail=( + "Invalid JSON response received from Anthropic: " + f"{self.response.text}, got error: {e}" + ), ) from e if self.response.status_code != 200: raise HTTPException( @@ -289,12 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): return Response( content=content, status_code=status_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(updated_headers), ) async def on_end(self): - """Checks guardrails after the response is received, and asynchronously pushes to Explorer.""" + """ + Checks guardrails after the response is received, and asynchronously + pushes to Explorer. + """ # ensure the response data is available assert self.response is not None, "response is None" assert self.response_json is not None, "response_json is None" @@ -383,7 +397,10 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.sse_buffer = "" # Buffer for incomplete events async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + """ + Check guardrails in a pipelined fashion, before processing the + first chunk (for input guardrailing). + """ if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( self.context, @@ -503,7 +520,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...", flush=True, ) - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Error processing event: {e}", flush=True) # on last stream chunk, run output guardrails @@ -536,7 +553,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): """Process the buffer and extract complete SSE events. Returns: - Tuple[List[str], str]: A tuple containing a list of + tuple[list[str], str]: A tuple containing a list of complete events and the remaining buffer with incomplete events. """ # Split on double newlines which separate SSE events @@ -582,7 +599,7 @@ async def handle_streaming_response( ) return StreamingResponse( - response.instrumented_event_generator(), media_type="text/event-stream" + response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM ) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 9a836e7..d386929 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response @@ -16,6 +16,8 @@ from gateway.common.config_manager import ( ) from gateway.common.constants import ( CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, IGNORED_HEADERS, ) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet @@ -47,7 +49,7 @@ async def gemini_generate_content_gateway( api_version: str, model: str, endpoint: str, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer alt: str = Query( None, title="Response Format", description="Set to 'sse' for streaming" ), @@ -58,8 +60,10 @@ async def gemini_generate_content_gateway( if endpoint not in ["generateContent", "streamGenerateContent"]: return Response( content="Invalid endpoint - the only endpoints supported are: \ - /api/v1/gateway/gemini//models/:generateContent or \ - /api/v1/gateway//gemini/models/:generateContent", + /api/v1/gateway/gemini//models/:generateContent \ + /api/v1/gateway//gemini/models/:generateContent \ + /api/v1/gateway/gemini//models/:streamGenerateContent or \ + /api/v1/gateway//gemini/models/:streamGenerateContent", status_code=400, ) headers = { @@ -80,7 +84,11 @@ async def gemini_generate_content_gateway( request_json = json.loads(request_body_bytes) client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) - gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}" + gemini_api_url = ( + f"https://generativelanguage.googleapis.com/" + f"{api_version}/models/" + f"{model}:{endpoint}" + ) if alt == "sse": gemini_api_url += "?alt=sse" gemini_request = client.build_request( @@ -139,7 +147,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): } # guardrailing execution result (if any) - self.guardrails_execution_result: Optional[dict[str, Any]] = None + self.guardrails_execution_result: dict[str, Any] | None = None def make_refusal( self, @@ -301,7 +309,7 @@ async def stream_response( return StreamingResponse( event_generator(), - media_type="text/event-stream", + media_type=CONTENT_TYPE_EVENT_STREAM, ) @@ -407,7 +415,7 @@ async def get_guardrails_check_result( async def push_to_explorer( context: RequestContext, response_json: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} @@ -456,11 +464,11 @@ class InstrumentedGeminiResponse(InstrumentedResponse): self.gemini_request: httpx.Request = gemini_request # response data - self.response: Optional[httpx.Response] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_json: dict[str, Any] | None = None # guardrails execution result (if any) - self.guardrails_execution_result: Optional[dict[str, Any]] = None + self.guardrails_execution_result: dict[str, Any] | None = None async def on_start(self): """ @@ -509,9 +517,9 @@ class InstrumentedGeminiResponse(InstrumentedResponse): Response( content=error_chunk, status_code=400, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers={ - "Content-Type": "application/json", + "Content-Type": CONTENT_TYPE_JSON, }, ) ) @@ -539,7 +547,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): return Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) @@ -582,7 +590,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) ) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 4d25871..1c8e1d9 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any, Optional +from typing import Any import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response @@ -16,6 +16,8 @@ from gateway.common.config_manager import ( ) from gateway.common.constants import ( CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, IGNORED_HEADERS, ) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet @@ -60,14 +62,14 @@ def make_cors_response(request: Request, allow_methods: str) -> Response: @gateway.options("/{dataset_name}/openai/chat/completions") @gateway.options("/openai/chat/completions") -async def openai_chat_completions_options(request: Request, dataset_name: str = None): +async def openai_chat_completions_options(request: Request): """Enables CORS for the OpenAI chat completions endpoint""" return make_cors_response(request, allow_methods="POST") @gateway.options("/{dataset_name}/openai/models") @gateway.options("/openai/models") -async def openai_models_options(request: Request, dataset_name: str = None): +async def openai_models_options(request: Request): """Enables CORS for the OpenAI models endpoint""" return make_cors_response(request, allow_methods="GET") @@ -76,7 +78,7 @@ async def openai_models_options(request: Request, dataset_name: str = None): @gateway.get("/openai/models") async def openai_models_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer ): """Proxy request to OpenAI /models endpoint""" headers = { @@ -110,7 +112,7 @@ async def openai_models_gateway( ) async def openai_chat_completions_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ) -> Response: @@ -180,7 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.open_ai_request: httpx.Request = open_ai_request # guardrailing output (if any) - self.guardrails_execution_result: Optional[dict] = None + self.guardrails_execution_result: dict | None = None # merged_response will be updated with the data from the chunks in the stream # At the end of the stream, this will be sent to the explorer @@ -273,7 +275,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): } ) - # yield an extra error chunk (without preventing the original chunk to go through after) + # yield an extra error chunk (without preventing the original + # chunk to go through after) return ExtraItem(f"data: {error_chunk}\n\n".encode()) # push will happen in on_end @@ -324,7 +327,7 @@ async def handle_stream_response( ) return StreamingResponse( - response.instrumented_event_generator(), media_type="text/event-stream" + response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM ) @@ -483,7 +486,7 @@ def create_metadata( async def push_to_explorer( context: RequestContext, merged_response: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the merged response to the Invariant Explorer""" # Only push the trace to explorer if the message is an end turn message @@ -569,11 +572,11 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.open_ai_request: httpx.Request = open_ai_request # request outputs - self.response: Optional[httpx.Response] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_json: dict[str, Any] | None = None # guardrailing output (if any) - self.guardrails_execution_result: Optional[dict] = None + self.guardrails_execution_result: dict | None = None async def on_start(self): """ @@ -606,7 +609,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): } ), status_code=400, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, ), end_of_stream=True, ) @@ -634,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): return Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) @@ -686,7 +689,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, ), ) diff --git a/pyproject.toml b/pyproject.toml index d52c4cb..224398e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "invariant-gateway" -version = "0.0.5.2" +version = "0.0.6" description = "LLM proxy to observe and debug what your AI agents are doing" readme = "README.md" requires-python = ">=3.12" diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 82cdc82..365af35 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -7,7 +7,6 @@ import sys import time import uuid from pathlib import Path -from typing import Dict, List # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -49,7 +48,7 @@ class WeatherAgent: }, } - def get_response(self, messages: List[Dict]) -> List[Dict]: + def get_response(self, messages: list[dict]) -> list[dict]: """ Get the response from the agent for a given user query for weather. """ @@ -83,7 +82,7 @@ class WeatherAgent: else: return response_list - def get_streaming_response(self, messages: List[Dict]) -> List[Dict]: + def get_streaming_response(self, messages: list[dict]) -> list[dict]: """Get streaming response from the agent for a given user query for weather.""" response_list = [] diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index f86e289..ff5d700 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -2,7 +2,6 @@ import os import uuid - from resources.mcp.sse.client.main import run as mcp_sse_client_run from resources.mcp.stdio.client.main import run as mcp_stdio_client_run from resources.mcp.streamable.client.main import run as mcp_streamable_client_run @@ -12,7 +11,6 @@ import httpx import pytest import requests from datetime import datetime -from mcp.shared.exceptions import McpError # Taken from docker-compose.test.yml MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server" @@ -706,7 +704,7 @@ async def test_mcp_message_timestamps( """Test that MCP messages include timestamps""" project_name = "test-mcp-" + str(uuid.uuid4()) - # Run the MCP client and make the tool call + result = await _invoke_mcp_tool( transport, gateway_url, diff --git a/tests/integration/resources/mcp/sse/client/main.py b/tests/integration/resources/mcp/sse/client/main.py index d74b592..b804b15 100644 --- a/tests/integration/resources/mcp/sse/client/main.py +++ b/tests/integration/resources/mcp/sse/client/main.py @@ -11,7 +11,7 @@ async def run( gateway_url: str, tool_name: str, tool_args: dict[str, Any], - headers: dict[str, str] = None, + headers: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Run the MCP client with the given parameters. diff --git a/tests/integration/resources/mcp/stdio/client/main.py b/tests/integration/resources/mcp/stdio/client/main.py index 3aeaaf3..5afad84 100644 --- a/tests/integration/resources/mcp/stdio/client/main.py +++ b/tests/integration/resources/mcp/stdio/client/main.py @@ -3,7 +3,7 @@ import os from datetime import timedelta -from typing import Any, Optional +from typing import Any from mcp import ClientSession, StdioServerParameters, types from mcp.client.stdio import stdio_client @@ -14,7 +14,7 @@ def _get_server_params( project_name: str, server_script_path: str, push_to_explorer: bool, - metadata_keys: Optional[dict[str, str]] = None, + metadata_keys: dict[str, str] | None = None, ) -> StdioServerParameters: args = [ "--from", @@ -59,7 +59,7 @@ async def run( push_to_explorer: bool, tool_name: str, tool_args: dict[str, Any], - metadata_keys: Optional[dict[str, str]] = None, + metadata_keys: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Main function to setup the MCP client and server. diff --git a/tests/integration/resources/mcp/streamable/client/main.py b/tests/integration/resources/mcp/streamable/client/main.py index da7fc23..6d972c7 100644 --- a/tests/integration/resources/mcp/streamable/client/main.py +++ b/tests/integration/resources/mcp/streamable/client/main.py @@ -12,7 +12,7 @@ async def run( gateway_url: str, tool_name: str, tool_args: dict[str, Any], - headers: dict[str, str] = None, + headers: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Run the MCP client with the given parameters. diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3c6d3bd..438f6cc 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -2,7 +2,7 @@ import os import uuid -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal from httpx import Client from openai import OpenAI @@ -62,8 +62,8 @@ def get_gemini_client( async def create_dataset( explorer_api_url: str, invariant_authorization: str, - dataset_name: Optional[str] = None, -) -> Dict[str, Any]: + dataset_name: str | None = None, +) -> dict[str, Any]: """Create a dataset in the Explorer API.""" client = Client(base_url=explorer_api_url) response = client.post( @@ -85,7 +85,7 @@ async def add_guardrail_to_dataset( policy: str, action: Literal["block", "log"], invariant_authorization: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Add a guardrail to a dataset.""" client = Client(base_url=explorer_api_url) response = client.post(