From 9053d89f25a05343ba76ced7e728f23174155978 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 4 Jun 2025 13:59:41 +0200 Subject: [PATCH 1/6] Small cleanups in MCP related code. --- gateway/mcp/mcp_transport_base.py | 68 +++++++++++++++---------------- gateway/mcp/sse.py | 11 +++-- gateway/mcp/stdio.py | 8 ++-- gateway/mcp/streamable.py | 17 ++++---- gateway/mcp/task_utils.py | 42 ------------------- 5 files changed, 50 insertions(+), 96 deletions(-) delete mode 100644 gateway/mcp/task_utils.py diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index 7557472..6367460 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -4,7 +4,6 @@ MCP Transport Strategy Pattern Implementation This module defines an abstract base class for MCP transports. """ -import asyncio import json import re import uuid @@ -31,7 +30,7 @@ from gateway.mcp.log import format_errors_in_response from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager -class MCPTransportBase(ABC): +class McpTransportBase(ABC): """ Abstract base class for MCP transport strategies. @@ -53,7 +52,7 @@ class MCPTransportBase(ABC): """ # Update session with request information session = self.session_store.get_session(session_id) - MCPTransportBase.update_session_from_request(session, request_data) + McpTransportBase.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -75,10 +74,10 @@ class MCPTransportBase(ABC): """ # Update session with server information session = self.session_store.get_session(session_id) - MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data) + McpTransportBase.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response - return await MCPTransportBase.intercept_response( + return await McpTransportBase.intercept_response( session_id, self.session_store, response_data ) @@ -87,6 +86,17 @@ class MCPTransportBase(ABC): method = request_data.get(MCP_METHOD) return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS] + @staticmethod + def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict: + return { + "jsonrpc": "2.0", + "id": request_body.get("id"), + "error": { + "code": -32600, + "message": message, + }, + } + async def _intercept_outgoing_request( self, session_id: str, request_data: dict[str, Any] ) -> Tuple[dict[str, Any], bool]: @@ -96,11 +106,11 @@ class MCPTransportBase(ABC): interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await McpTransportBase.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await MCPTransportBase.hook_tool_call( + interception_result, is_blocked = await McpTransportBase.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ @@ -152,8 +162,8 @@ class MCPTransportBase(ABC): @staticmethod def update_session_from_request(session: McpSession, request_body: dict) -> None: """Update the MCP client information and request id in the session.""" - MCPTransportBase.update_mcp_client_info_in_session(session, request_body) - MCPTransportBase.update_tool_call_id_in_session(session, request_body) + McpTransportBase.update_mcp_client_info_in_session(session, request_body) + McpTransportBase.update_tool_call_id_in_session(session, request_body) @staticmethod def get_mcp_server_base_url(request: Request) -> str: @@ -164,7 +174,7 @@ class MCPTransportBase(ABC): status_code=400, detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header", ) - return MCPTransportBase.convert_localhost_to_docker_host( + return McpTransportBase.convert_localhost_to_docker_host( mcp_server_base_url ).rstrip("/") @@ -233,7 +243,7 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): @@ -243,15 +253,10 @@ class MCPTransportBase(ABC): message=message, guardrails_result=guardrails_result, ) - return { - "jsonrpc": "2.0", - "id": request_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - }, True + return McpTransportBase._create_jsonrpc_error_response( + request_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ), True # Push trace to the explorer await session_store.add_message_to_session( @@ -298,22 +303,17 @@ class MCPTransportBase(ABC): if ( guardrails_result and guardrails_result.get("errors", []) - and MCPTransportBase.check_if_new_errors( + and McpTransportBase.check_if_new_errors( session_id, session_store, guardrails_result ) ): is_blocked = True if not is_tools_list: - result = { - "jsonrpc": "2.0", - "id": response_body.get("id"), - "error": { - "code": -32600, - "message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE - % guardrails_result["errors"], - }, - } + result = McpTransportBase._create_jsonrpc_error_response( + response_body, + INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"], + ) else: # Special error response for tools/list result = { @@ -379,7 +379,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body=response_body, @@ -393,7 +393,7 @@ class MCPTransportBase(ABC): ( intercept_response_result, is_blocked, - ) = await MCPTransportBase.hook_tool_call_response( + ) = await McpTransportBase.hook_tool_call_response( session_id=session_id, session_store=session_store, response_body={ @@ -410,9 +410,9 @@ class MCPTransportBase(ABC): return intercept_response_result, is_blocked @abstractmethod - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize a session for this transport type.""" @abstractmethod - async def handle_communication(self, *args, **kwargs) -> Any: + async def handle_communication(self, **kwargs) -> Any: """Handle the main communication for this transport.""" diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 3109175..998d5de 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -16,7 +16,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase MCP_SERVER_POST_HEADERS = { "connection", @@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post( raise HTTPException(status_code=400, detail="Session does not exist") request_body = json.loads(await request.body()) - return await SSETransport(session_store).handle_post_request( + return await SseTransport(session_store).handle_post_request( request, session_id, request_body ) @@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream( request: Request, session_store: McpSessionsManager ) -> StreamingResponse: """Integration function for SSE GET route.""" - return await SSETransport(session_store).handle_sse_stream(request) + return await SseTransport(session_store).handle_sse_stream(request) -class SSETransport(MCPTransportBase): +class SseTransport(McpTransportBase): """ Server-Sent Events transport implementation for MCP communication. Handles HTTP-based SSE communication with message queuing. @@ -82,7 +82,6 @@ class SSETransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize or get existing SSE session.""" @@ -298,7 +297,7 @@ class SSETransport(MCPTransportBase): headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) - async def handle_communication(self, *args, **kwargs) -> StreamingResponse: + async def handle_communication(self, **kwargs) -> StreamingResponse: """Main communication handler for SSE transport.""" return await self.handle_sse_stream(kwargs.get("request")) diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py index 4a5fd73..a44fc27 100644 --- a/gateway/mcp/stdio.py +++ b/gateway/mcp/stdio.py @@ -15,7 +15,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpAttributes, McpSessionsManager, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase STATUS_EOF = "eof" STATUS_DATA = "data" @@ -23,7 +23,7 @@ STATUS_WAIT = "wait" mcp_sessions_manager = McpSessionsManager() -class StdioTransport(MCPTransportBase): +class StdioTransport(McpTransportBase): """ STDIO transport implementation for MCP communication. Handles subprocess-based communication with stdin/stdout/stderr. @@ -33,7 +33,7 @@ class StdioTransport(MCPTransportBase): super().__init__(session_store) self.mcp_process: subprocess.Popen = None - async def initialize_session(self, *args, **kwargs) -> str: + async def initialize_session(self, **kwargs) -> str: """Initialize session for stdio transport.""" session_attributes: McpAttributes = kwargs.get("session_attributes") session_id = self.generate_session_id() @@ -53,7 +53,7 @@ class StdioTransport(MCPTransportBase): mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}") return self.mcp_process - async def handle_communication(self, *args, **kwargs) -> None: + async def handle_communication(self, **kwargs) -> None: """Handle stdio communication loop.""" session_id: str = kwargs.get("session_id") mcp_process: subprocess.Popen = kwargs.get("mcp_process") diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index 39a2b01..b066fb7 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -1,7 +1,7 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" import json -from typing import Any, Optional, Union +from typing import Any, Optional import httpx from httpx_sse import aconnect_sse @@ -18,7 +18,7 @@ from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, McpAttributes, ) -from gateway.mcp.mcp_transport_base import MCPTransportBase +from gateway.mcp.mcp_transport_base import McpTransportBase gateway = APIRouter() mcp_sessions_manager = McpSessionsManager() @@ -69,7 +69,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response: async def create_streamable_transport_and_handle_request( request: Request, method: str, session_store: McpSessionsManager -) -> Union[Response, StreamingResponse]: +) -> Response | StreamingResponse: """Integration function for streamable routes.""" streamable_transport = StreamableTransport(session_store) return await streamable_transport.handle_communication( @@ -77,7 +77,7 @@ async def create_streamable_transport_and_handle_request( ) -class StreamableTransport(MCPTransportBase): +class StreamableTransport(McpTransportBase): """ Streamable HTTP transport implementation for MCP communication. Handles HTTP POST/GET/DELETE requests with JSON and streaming responses. @@ -85,7 +85,6 @@ class StreamableTransport(MCPTransportBase): async def initialize_session( self, - *args, **kwargs, ) -> str: """Initialize streamable HTTP session.""" @@ -111,7 +110,7 @@ class StreamableTransport(MCPTransportBase): async def handle_post_request( self, request: Request, request_body: dict[str, Any] - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Handle POST request to streamable endpoint.""" session_attributes = McpAttributes.from_request_headers(request.headers) session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -222,9 +221,7 @@ class StreamableTransport(MCPTransportBase): print(f"[MCP DELETE] Request error: {str(e)}") raise HTTPException(status_code=500, detail="Request error") from e - async def handle_communication( - self, *args, **kwargs - ) -> Union[Response, StreamingResponse]: + async def handle_communication(self, **kwargs) -> Response | StreamingResponse: """Main communication handler for streamable transport.""" request = kwargs.get("request") method = kwargs.get("method", "POST") @@ -262,7 +259,7 @@ class StreamableTransport(MCPTransportBase): session_id: str, session_attributes: McpAttributes, is_initialization_request: bool, - ) -> Union[Response, StreamingResponse]: + ) -> Response | StreamingResponse: """Forward request to MCP server and handle response.""" async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client: try: diff --git a/gateway/mcp/task_utils.py b/gateway/mcp/task_utils.py deleted file mode 100644 index 2ff8b87..0000000 --- a/gateway/mcp/task_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Task utilities for running async functions""" - -import asyncio -import concurrent.futures - -from contextlib import redirect_stdout -from typing import Any - -from gateway.mcp.log import MCP_LOG_FILE - - -def run_task_sync(async_func, *args, **kwargs) -> Any: - """ - Runs an asynchronous function synchronously in a separate - thread with its own event loop. This function blocks the calling - thread until completion or timeout (10 seconds). - - Args: - async_func: The async function to run - *args: Positional arguments to pass to the async function - **kwargs: Keyword arguments to pass to the async function - - Returns: - Any: The return value of the async function - """ - - def run_in_new_loop(): - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete( - async_func( - *args, - **kwargs, - ) - ) - finally: - loop.close() - - with redirect_stdout(MCP_LOG_FILE): - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_new_loop) - return future.result(timeout=10.0) From da03dbe7c5d886f252d103cfaa137d8fd7f2104b Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 4 Jun 2025 14:15:08 +0200 Subject: [PATCH 2/6] Move is_stateless_http_server metadata assignment to the Streamable route from the common metadata method. --- gateway/mcp/mcp_sessions_manager.py | 4 ---- gateway/mcp/mcp_transport_base.py | 8 ++++---- gateway/mcp/streamable.py | 7 +++++-- tests/integration/mcp/test_mcp.py | 3 --- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/gateway/mcp/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py index bb915d2..9add397 100644 --- a/gateway/mcp/mcp_sessions_manager.py +++ b/gateway/mcp/mcp_sessions_manager.py @@ -23,7 +23,6 @@ from gateway.integrations.explorer import ( fetch_guardrails_from_explorer, ) from gateway.integrations.guardrails import check_guardrails -from gateway.mcp.constants import INVARIANT_SESSION_ID_PREFIX def user_and_host() -> str: @@ -110,9 +109,6 @@ class McpSession(BaseModel): "system_user": user_and_host(), **(self.attributes.metadata or {}), } - metadata["is_stateless_http_server"] = self.session_id.startswith( - INVARIANT_SESSION_ID_PREFIX - ) return metadata async def get_guardrails_check_result( diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index 6367460..a0118c3 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -52,7 +52,7 @@ class McpTransportBase(ABC): """ # Update session with request information session = self.session_store.get_session(session_id) - McpTransportBase.update_session_from_request(session, request_data) + self.update_session_from_request(session, request_data) # Refresh guardrails await session.load_guardrails() @@ -74,7 +74,7 @@ class McpTransportBase(ABC): """ # Update session with server information session = self.session_store.get_session(session_id) - McpTransportBase.update_mcp_server_in_session_metadata(session, response_data) + self.update_mcp_server_in_session_metadata(session, response_data) # Intercept and apply guardrails to response return await McpTransportBase.intercept_response( @@ -106,11 +106,11 @@ class McpTransportBase(ABC): interception_result = request_data is_blocked = False if method == MCP_TOOL_CALL: - interception_result, is_blocked = await McpTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id, self.session_store, request_data ) elif method == MCP_LIST_TOOLS: - interception_result, is_blocked = await McpTransportBase.hook_tool_call( + interception_result, is_blocked = await self.hook_tool_call( session_id=session_id, session_store=self.session_store, request_body={ diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index b066fb7..bf5f694 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -27,7 +27,7 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" CONTENT_TYPE_HEADER = "content-type" MCP_SESSION_ID_HEADER = "mcp-session-id" -MCP_SERVER_POST_DELETE_HEADERS = { +MCP_SERVER_POST_AND_DELETE_HEADERS = { "connection", "accept", CONTENT_TYPE_HEADER, @@ -392,6 +392,9 @@ class StreamableTransport(McpTransportBase): """Update MCP response info in session metadata.""" session = self.session_store.get_session(session_id) self.update_mcp_server_in_session_metadata(session, response_body) + session.attributes.metadata["is_stateless_http_server"] = session_id.startswith( + INVARIANT_SESSION_ID_PREFIX + ) session.attributes.metadata["server_response_type"] = ( "json" if is_json_response else "sse" ) @@ -402,7 +405,7 @@ class StreamableTransport(McpTransportBase): for k, v in request.headers.items(): if k.startswith(MCP_CUSTOM_HEADER_PREFIX): filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v - if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not ( + if k.lower() in MCP_SERVER_POST_AND_DELETE_HEADERS and not ( k.lower() == MCP_SESSION_ID_HEADER and v.startswith(INVARIANT_SESSION_ID_PREFIX) ): diff --git a/tests/integration/mcp/test_mcp.py b/tests/integration/mcp/test_mcp.py index 2606c2b..e65262b 100644 --- a/tests/integration/mcp/test_mcp.py +++ b/tests/integration/mcp/test_mcp.py @@ -2,7 +2,6 @@ import os import uuid - from resources.mcp.sse.client.main import run as mcp_sse_client_run from resources.mcp.stdio.client.main import run as mcp_stdio_client_run from resources.mcp.streamable.client.main import run as mcp_streamable_client_run @@ -12,8 +11,6 @@ import httpx import pytest import requests -from mcp.shared.exceptions import McpError - # Taken from docker-compose.test.yml MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server" MCP_SSE_SERVER_PORT = 8123 From f184c488e8757b0a695455726d6d94a18378e8a1 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 4 Jun 2025 14:23:22 +0200 Subject: [PATCH 3/6] Bump to version 0.0.6 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d52c4cb..224398e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "invariant-gateway" -version = "0.0.5.2" +version = "0.0.6" description = "LLM proxy to observe and debug what your AI agents are doing" readme = "README.md" requires-python = ">=3.12" From 24d47c4585b7f55516a567feb6559b8a5820ab84 Mon Sep 17 00:00:00 2001 From: Hemang Date: Wed, 4 Jun 2025 15:04:56 +0200 Subject: [PATCH 4/6] Update gemini route to include streamGenerateContent in allowed endpoints response. --- gateway/integrations/guardrails.py | 4 ++-- gateway/routes/gemini.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 3424881..aa73dad 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -364,11 +364,11 @@ async def check_guardrails( json={ "messages": messages, "policies": [g.content for g in guardrails], - "parameters": context.guardrails_parameters or {} + "parameters": context.guardrails_parameters or {}, }, headers={ "Authorization": context.get_guardrailing_authorization(), - "Accept": "application/json" + "Accept": "application/json", }, timeout=5, ) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 9a836e7..6e7fe26 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -58,8 +58,10 @@ async def gemini_generate_content_gateway( if endpoint not in ["generateContent", "streamGenerateContent"]: return Response( content="Invalid endpoint - the only endpoints supported are: \ - /api/v1/gateway/gemini//models/:generateContent or \ - /api/v1/gateway//gemini/models/:generateContent", + /api/v1/gateway/gemini//models/:generateContent \ + /api/v1/gateway//gemini/models/:generateContent \ + /api/v1/gateway/gemini//models/:streamGenerateContent or \ + /api/v1/gateway//gemini/models/:streamGenerateContent", status_code=400, ) headers = { From cd6d6a50b05838e21cad1f38a587ef89bef3f151 Mon Sep 17 00:00:00 2001 From: Hemang Date: Thu, 5 Jun 2025 09:55:02 +0200 Subject: [PATCH 5/6] Small changes related to constants and sorting order of imports. --- gateway/common/config_manager.py | 1 - gateway/common/constants.py | 3 +++ gateway/common/guardrails.py | 4 +--- gateway/common/request_context.py | 8 +++---- gateway/integrations/explorer.py | 5 ++-- gateway/integrations/guardrails.py | 11 +++++---- gateway/mcp/log.py | 1 - gateway/mcp/sse.py | 4 ++-- gateway/mcp/streamable.py | 14 ++++++----- gateway/routes/anthropic.py | 37 ++++++++++++++++++++++-------- gateway/routes/gemini.py | 18 ++++++++++----- gateway/routes/open_ai.py | 13 +++++++---- 12 files changed, 74 insertions(+), 45 deletions(-) diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 98d83e5..0c6095d 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -10,7 +10,6 @@ from httpx import HTTPStatusError from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet - def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]: """ Extracts the guardrailing policy from the request headers if present. diff --git a/gateway/common/constants.py b/gateway/common/constants.py index a84eaea..20ab1d8 100644 --- a/gateway/common/constants.py +++ b/gateway/common/constants.py @@ -16,3 +16,6 @@ IGNORED_HEADERS = [ CLIENT_TIMEOUT = 60.0 +CONTENT_TYPE_HEADER = "content-type" +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_EVENT_STREAM = "text/event-stream" diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py index e4164c1..ec7a94e 100644 --- a/gateway/common/guardrails.py +++ b/gateway/common/guardrails.py @@ -1,11 +1,9 @@ """Common guardrails data class.""" +from dataclasses import dataclass from enum import Enum from typing import List -from dataclasses import dataclass - - class GuardrailAction(str, Enum): """Enum representing the action to be taken for guardrail rules.""" diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py index 11b477d..fbce8a3 100644 --- a/gateway/common/request_context.py +++ b/gateway/common/request_context.py @@ -5,11 +5,11 @@ from typing import Any, Dict, Optional import fastapi -from gateway.common.config_manager import GatewayConfig -from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from gateway.common.authorization import ( extract_guardrail_service_authorization_from_headers, ) +from gateway.common.config_manager import GatewayConfig +from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction @dataclass(frozen=True) @@ -25,7 +25,7 @@ class RequestContext: # the set of guardrails to enforce for this request guardrails: Optional[GuardrailRuleSet] = None config: Dict[str, Any] = None - + # extra parameters available as input. during guardrail evaluation guardrails_parameters: Optional[Dict[str, Any]] = None @@ -100,7 +100,7 @@ class RequestContext: guardrails=guardrails, config=context_config, _created_via_factory=True, - guardrails_parameters=guardrails_parameters + guardrails_parameters=guardrails_parameters, ) def get_guardrailing_authorization(self) -> Optional[str]: diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index e547831..621f1c2 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -2,8 +2,9 @@ import os import json - from typing import Any, Dict, List + +import httpx from fastapi import HTTPException from gateway.common.constants import DEFAULT_API_URL @@ -12,8 +13,6 @@ from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse from invariant_sdk.types.annotations import AnnotationCreate -import httpx - def create_annotations_from_guardrails_errors( guardrails_errors: List[dict], diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index aa73dad..2b3f912 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -376,11 +376,14 @@ async def check_guardrails( if result.status_code == 401: raise HTTPException( status_code=401, - detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format( - INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER + detail=( + "The provided Invariant API key is not valid for guardrail checking. " + "Please ensure you are using the correct API key or pass an " + "alternative API key for guardrail checking specifically via the " + f"'{INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER}' header." ), ) - raise Exception( + raise Exception( # pylint: disable=broad-exception-raised f"Guardrails check failed: {result.status_code} - {result.text}" ) guardrails_result = result.json() @@ -412,7 +415,7 @@ async def check_guardrails( return aggregated_errors except HTTPException as e: raise e - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Failed to verify guardrails: {e}") # make sure runtime errors are also visible in e.g. Explorer return { diff --git a/gateway/mcp/log.py b/gateway/mcp/log.py index 7287513..388355b 100644 --- a/gateway/mcp/log.py +++ b/gateway/mcp/log.py @@ -2,7 +2,6 @@ import os import sys - from builtins import print as builtins_print os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True) diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 998d5de..6c6b971 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -10,7 +10,7 @@ from httpx_sse import aconnect_sse, ServerSentEvent from fastapi import APIRouter, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from gateway.common.constants import CLIENT_TIMEOUT +from gateway.common.constants import CLIENT_TIMEOUT, CONTENT_TYPE_EVENT_STREAM from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8 from gateway.mcp.mcp_sessions_manager import ( McpSessionsManager, @@ -293,7 +293,7 @@ class SseTransport(McpTransportBase): return StreamingResponse( event_generator(), - media_type="text/event-stream", + media_type=CONTENT_TYPE_EVENT_STREAM, headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index bf5f694..930abbd 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -8,7 +8,12 @@ from httpx_sse import aconnect_sse from fastapi import APIRouter, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from gateway.common.constants import CLIENT_TIMEOUT +from gateway.common.constants import ( + CLIENT_TIMEOUT, + CONTENT_TYPE_HEADER, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, +) from gateway.mcp.constants import ( INVARIANT_SESSION_ID_PREFIX, MCP_CUSTOM_HEADER_PREFIX, @@ -23,9 +28,6 @@ from gateway.mcp.mcp_transport_base import McpTransportBase gateway = APIRouter() mcp_sessions_manager = McpSessionsManager() -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_SSE = "text/event-stream" -CONTENT_TYPE_HEADER = "content-type" MCP_SESSION_ID_HEADER = "mcp-session-id" MCP_SERVER_POST_AND_DELETE_HEADERS = { "connection", @@ -187,7 +189,7 @@ class StreamableTransport(McpTransportBase): return StreamingResponse( event_generator(), - media_type=CONTENT_TYPE_SSE, + media_type=CONTENT_TYPE_EVENT_STREAM, headers={"X-Proxied-By": "mcp-gateway", **response_headers}, ) @@ -382,7 +384,7 @@ class StreamableTransport(McpTransportBase): return StreamingResponse( event_generator(), - media_type=CONTENT_TYPE_SSE, + media_type=CONTENT_TYPE_EVENT_STREAM, headers=response_headers, ) diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 36c4a63..e455b11 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -14,7 +14,12 @@ from gateway.common.config_manager import ( GatewayConfigManager, extract_guardrails_from_header, ) -from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS +from gateway.common.constants import ( + CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, + IGNORED_HEADERS, +) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.request_context import RequestContext from gateway.converters.anthropic_to_invariant import ( @@ -218,7 +223,10 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): self.guardrails_execution_result = {} async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + """ + Check guardrails in a pipelined fashion, before processing the first + chunk (for input guardrailing). + """ if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( self.context, action=GuardrailAction.BLOCK, response_json={} @@ -249,8 +257,8 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): Response( content=error_chunk, status_code=400, - media_type="application/json", - headers={"content-type": "application/json"}, + media_type=CONTENT_TYPE_JSON, + headers={"content-type": CONTENT_TYPE_JSON}, ) ) @@ -263,7 +271,10 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): except json.JSONDecodeError as e: raise HTTPException( status_code=self.response.status_code, - detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}", + detail=( + "Invalid JSON response received from Anthropic: " + f"{self.response.text}, got error: {e}" + ), ) from e if self.response.status_code != 200: raise HTTPException( @@ -289,12 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): return Response( content=content, status_code=status_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(updated_headers), ) async def on_end(self): - """Checks guardrails after the response is received, and asynchronously pushes to Explorer.""" + """ + Checks guardrails after the response is received, and asynchronously + pushes to Explorer. + """ # ensure the response data is available assert self.response is not None, "response is None" assert self.response_json is not None, "response_json is None" @@ -383,7 +397,10 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): self.sse_buffer = "" # Buffer for incomplete events async def on_start(self): - """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + """ + Check guardrails in a pipelined fashion, before processing the + first chunk (for input guardrailing). + """ if self.context.guardrails: self.guardrails_execution_result = await get_guardrails_check_result( self.context, @@ -503,7 +520,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...", flush=True, ) - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Error processing event: {e}", flush=True) # on last stream chunk, run output guardrails @@ -582,7 +599,7 @@ async def handle_streaming_response( ) return StreamingResponse( - response.instrumented_event_generator(), media_type="text/event-stream" + response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM ) diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 6e7fe26..14a3a63 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -16,6 +16,8 @@ from gateway.common.config_manager import ( ) from gateway.common.constants import ( CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, IGNORED_HEADERS, ) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet @@ -82,7 +84,11 @@ async def gemini_generate_content_gateway( request_json = json.loads(request_body_bytes) client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) - gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}" + gemini_api_url = ( + f"https://generativelanguage.googleapis.com/" + f"{api_version}/models/" + f"{model}:{endpoint}" + ) if alt == "sse": gemini_api_url += "?alt=sse" gemini_request = client.build_request( @@ -303,7 +309,7 @@ async def stream_response( return StreamingResponse( event_generator(), - media_type="text/event-stream", + media_type=CONTENT_TYPE_EVENT_STREAM, ) @@ -511,9 +517,9 @@ class InstrumentedGeminiResponse(InstrumentedResponse): Response( content=error_chunk, status_code=400, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers={ - "Content-Type": "application/json", + "Content-Type": CONTENT_TYPE_JSON, }, ) ) @@ -541,7 +547,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): return Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) @@ -584,7 +590,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) ) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 4d25871..2d8501e 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -16,6 +16,8 @@ from gateway.common.config_manager import ( ) from gateway.common.constants import ( CLIENT_TIMEOUT, + CONTENT_TYPE_JSON, + CONTENT_TYPE_EVENT_STREAM, IGNORED_HEADERS, ) from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet @@ -273,7 +275,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): } ) - # yield an extra error chunk (without preventing the original chunk to go through after) + # yield an extra error chunk (without preventing the original + # chunk to go through after) return ExtraItem(f"data: {error_chunk}\n\n".encode()) # push will happen in on_end @@ -324,7 +327,7 @@ async def handle_stream_response( ) return StreamingResponse( - response.instrumented_event_generator(), media_type="text/event-stream" + response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM ) @@ -606,7 +609,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): } ), status_code=400, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, ), end_of_stream=True, ) @@ -634,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): return Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, headers=dict(self.response.headers), ) @@ -686,7 +689,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): Response( content=response_string, status_code=response_code, - media_type="application/json", + media_type=CONTENT_TYPE_JSON, ), ) From 42a9c1cc30ea32f02d49009024b902bf513d4357 Mon Sep 17 00:00:00 2001 From: Hemang Date: Thu, 5 Jun 2025 11:17:35 +0200 Subject: [PATCH 6/6] Readability changes. --- gateway/__main__.py | 6 +- gateway/common/authorization.py | 11 +- gateway/common/config_manager.py | 9 +- gateway/common/guardrails.py | 5 +- gateway/common/request_context.py | 32 +-- gateway/integrations/explorer.py | 18 +- gateway/integrations/guardrails.py | 14 +- gateway/mcp/mcp_sessions_manager.py | 205 +++++++++--------- gateway/mcp/mcp_transport_base.py | 24 +- gateway/mcp/sse.py | 8 +- gateway/mcp/stdio.py | 5 +- gateway/mcp/streamable.py | 8 +- gateway/routes/anthropic.py | 12 +- gateway/routes/gemini.py | 14 +- gateway/routes/open_ai.py | 20 +- .../test_anthropic_with_tool_call.py | 5 +- .../resources/mcp/sse/client/main.py | 2 +- .../resources/mcp/stdio/client/main.py | 6 +- .../resources/mcp/streamable/client/main.py | 2 +- tests/integration/utils.py | 8 +- 20 files changed, 204 insertions(+), 210 deletions(-) diff --git a/gateway/__main__.py b/gateway/__main__.py index de26a66..f994253 100644 --- a/gateway/__main__.py +++ b/gateway/__main__.py @@ -7,8 +7,6 @@ import subprocess import sys import time -from typing import Optional - from gateway.mcp import stdio as mcp_stdio from gateway.mcp.log import mcp_log @@ -64,7 +62,7 @@ def ensure_network_exists(network_name: str = "invariant-explorer-web") -> bool: return False -def setup_guardrails(guardrails_file_path: Optional[str] = None) -> bool: +def setup_guardrails(guardrails_file_path: str | None = None) -> bool: """Configure guardrails if specified.""" if not guardrails_file_path: return True @@ -105,7 +103,7 @@ def build(): return False -def up(guardrails_file_path: Optional[str] = None): +def up(guardrails_file_path: str | None = None): """Set up the local server for the Invariant Gateway.""" # Ensure network exists if not ensure_network_exists(): diff --git a/gateway/common/authorization.py b/gateway/common/authorization.py index 124b8d3..b713950 100644 --- a/gateway/common/authorization.py +++ b/gateway/common/authorization.py @@ -1,6 +1,5 @@ """Common Authorization functions used in the gateway.""" -from typing import Tuple, Optional from fastapi import HTTPException, Request INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization" @@ -10,7 +9,7 @@ API_KEYS_SEPARATOR = ";invariant-auth=" def extract_guardrail_service_authorization_from_headers( request: Request, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """ Extracts the optional Invariant-Guardrails-Authorization authorization header from the request. @@ -22,10 +21,10 @@ def extract_guardrail_service_authorization_from_headers( def extract_authorization_from_headers( request: Request, - dataset_name: Optional[str] = None, - llm_provider_api_key_header: Optional[str] = None, - llm_provider_fallback_api_key_headers: Optional[list[str]] = None, -) -> Tuple[Optional[str], Optional[str]]: + dataset_name: str | None = None, + llm_provider_api_key_header: str | None = None, + llm_provider_fallback_api_key_headers: list[str] | None = None, +) -> tuple[str | None, str | None]: """ Extracts the Invariant authorization and LLM Provider API key from the request headers. diff --git a/gateway/common/config_manager.py b/gateway/common/config_manager.py index 0c6095d..af2f310 100644 --- a/gateway/common/config_manager.py +++ b/gateway/common/config_manager.py @@ -3,14 +3,14 @@ import asyncio import os import threading -from typing import Optional import fastapi from httpx import HTTPStatusError from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet -def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]: + +def extract_policy_from_headers(request: fastapi.Request | None) -> str | None: """ Extracts the guardrailing policy from the request headers if present. @@ -78,7 +78,7 @@ class GatewayConfigManager: _lock = threading.Lock() @classmethod - def get_config(cls, request: fastapi.Request = None) -> GatewayConfig: + def get_config(cls) -> GatewayConfig: """Initializes and returns the gateway configuration using double-checked locking.""" local_config = cls._config_instance @@ -94,7 +94,7 @@ class GatewayConfigManager: async def extract_guardrails_from_header( request: fastapi.Request, -) -> Optional[GuardrailRuleSet]: +) -> GuardrailRuleSet | None: """ Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding GuardrailRuleSet. If no guardrails are provided, returns None. @@ -114,3 +114,4 @@ async def extract_guardrails_from_header( blocking_guardrails=guardrails, logging_guardrails=[], ) + return None diff --git a/gateway/common/guardrails.py b/gateway/common/guardrails.py index ec7a94e..73c35f7 100644 --- a/gateway/common/guardrails.py +++ b/gateway/common/guardrails.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from enum import Enum -from typing import List class GuardrailAction(str, Enum): """Enum representing the action to be taken for guardrail rules.""" @@ -25,5 +24,5 @@ class Guardrail: class GuardrailRuleSet: """Grouped guardrail rules separated by their action.""" - blocking_guardrails: List[Guardrail] - logging_guardrails: List[Guardrail] + blocking_guardrails: list[Guardrail] + logging_guardrails: list[Guardrail] diff --git a/gateway/common/request_context.py b/gateway/common/request_context.py index fbce8a3..2f6afa0 100644 --- a/gateway/common/request_context.py +++ b/gateway/common/request_context.py @@ -1,7 +1,7 @@ """Common Request context data class.""" from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any import fastapi @@ -16,18 +16,18 @@ from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailActi class RequestContext: """Structured context for a request. Must be created via `RequestContext.create()`.""" - request_json: Dict[str, Any] - dataset_name: Optional[str] = None + request_json: dict[str, Any] + dataset_name: str | None = None # authorization to use for invariant service like explorer - invariant_authorization: Optional[str] = None + invariant_authorization: str | None = None # authorization to use for invariant guardrailing specifically - guardrail_authorization: Optional[str] = None + guardrail_authorization: str | None = None # the set of guardrails to enforce for this request - guardrails: Optional[GuardrailRuleSet] = None - config: Dict[str, Any] = None + guardrails: GuardrailRuleSet | None = None + config: dict[str, Any] | None = None # extra parameters available as input. during guardrail evaluation - guardrails_parameters: Optional[Dict[str, Any]] = None + guardrails_parameters: dict[str, Any] | None = None _created_via_factory: bool = field( default=False, init=True, repr=False, compare=False @@ -42,13 +42,13 @@ class RequestContext: @classmethod def create( cls, - request_json: Dict[str, Any], - dataset_name: Optional[str] = None, - invariant_authorization: Optional[str] = None, - guardrails: Optional[GuardrailRuleSet] = None, - config: Optional[GatewayConfig] = None, - request: fastapi.Request = None, - guardrails_parameters: Optional[Dict[str, Any]] = None, + request_json: dict[str, Any], + dataset_name: str | None = None, + invariant_authorization: str | None = None, + guardrails: GuardrailRuleSet | None = None, + config: GatewayConfig | None = None, + request: fastapi.Request | None = None, + guardrails_parameters: dict[str, Any] | None = None, ) -> "RequestContext": """Creates a new RequestContext instance, applying default guardrails if needed.""" @@ -103,7 +103,7 @@ class RequestContext: guardrails_parameters=guardrails_parameters, ) - def get_guardrailing_authorization(self) -> Optional[str]: + def get_guardrailing_authorization(self) -> str | None: """ Returns the authorization to use for the guardrailing service. diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index 621f1c2..007ccc0 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -2,7 +2,7 @@ import os import json -from typing import Any, Dict, List +from typing import Any import httpx from fastapi import HTTPException @@ -15,8 +15,8 @@ from invariant_sdk.types.annotations import AnnotationCreate def create_annotations_from_guardrails_errors( - guardrails_errors: List[dict], -) -> List[AnnotationCreate]: + guardrails_errors: list[dict], +) -> list[AnnotationCreate]: """Create Explorer annotations from the guardrails errors.""" annotations = [] @@ -67,7 +67,7 @@ def create_annotations_from_guardrails_errors( return remove_duplicates(annotations) -def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]: +def remove_duplicates(annotations: list[AnnotationCreate]) -> list[AnnotationCreate]: """ Remove duplicate annotations based on content, address, and extra_metadata. @@ -98,18 +98,18 @@ def get_explorer_api_url() -> str: async def push_trace( - messages: List[List[Dict[str, Any]]], + messages: list[list[dict[str, Any]]], dataset_name: str, invariant_authorization: str, - annotations: List[List[AnnotationCreate]] = None, - metadata: List[Dict[str, Any]] = None, + annotations: list[list[AnnotationCreate]] | None = None, + metadata: list[dict[str, Any]] | None = None, ) -> PushTracesResponse: """Pushes traces to the dataset on the Invariant Explorer. If a dataset with the given name does not exist, it will be created. Args: - messages (List[List[Dict[str, Any]]]): List of messages to push. + messages (listlistdict[str, Any]]]): List of messages to push. dataset_name (str): Name of the dataset. invariant_authorization (str): Value of the invariant-authorization header. @@ -134,7 +134,7 @@ async def push_trace( ) try: return await client.push_trace(request) - except Exception as e: + except Exception as e: # pylint: disable=broad-except print(f"Failed to push trace: {e}") return {"error": str(e)} diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 2b3f912..7b2a28f 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -3,7 +3,7 @@ import asyncio import os import time -from typing import Any, Dict, List +from typing import Any from functools import wraps from fastapi import HTTPException @@ -339,22 +339,22 @@ class InstrumentedResponse(InstrumentedStreamingResponse): async def check_guardrails( - messages: List[Dict[str, Any]], - guardrails: List[Guardrail], + messages: list[dict[str, Any]], + guardrails: list[Guardrail], context: RequestContext, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Checks guardrails on the list of messages. This calls the batch check API of the Guardrails service. Args: - messages (List[Dict[str, Any]]): List of messages to verify the guardrails against. - guardrails (List[Guardrail]): The guardrails to check against. + messages (list[dict[str, Any]]): List of messages to verify the guardrails against. + guardrails (list[Guardrail]): The guardrails to check against. invariant_authorization (str): Value of the invariant-authorization header. Returns: - Dict: Response containing guardrail check results. + dict: Response containing guardrail check results. """ async with httpx.AsyncClient() as client: url = os.getenv("GUARDRAILS_API_URL", DEFAULT_API_URL).rstrip("/") diff --git a/gateway/mcp/mcp_sessions_manager.py b/gateway/mcp/mcp_sessions_manager.py index 9add397..ece6a3e 100644 --- a/gateway/mcp/mcp_sessions_manager.py +++ b/gateway/mcp/mcp_sessions_manager.py @@ -7,7 +7,7 @@ import getpass import os import random import socket -from typing import Any, Optional +from typing import Any from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.append_messages import AppendMessagesRequest @@ -32,6 +32,105 @@ def user_and_host() -> str: return f"{username}@{hostname}" +class McpAttributes(BaseModel): + """ + A Pydantic model to represent MCP attributes. + This can be initialized using HTTP headers for SSE and Streamable transports. + This can also be initialized using CLI arguments for the Stdio transport. + """ + + push_explorer: bool + explorer_dataset: str + invariant_api_key: str | None = None + verbose: bool | None = False + metadata: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_request_headers(cls, headers: Headers) -> "McpAttributes": + """ + Create an instance from FastAPI request headers. + + Args: + headers: FastAPI Request headers + + Returns: + McpAttributes: An instance with values extracted from headers + """ + # Extract and process header values + project_name = headers.get("INVARIANT-PROJECT-NAME") + push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower() + invariant_api_key = headers.get("INVARIANT-API-KEY") + + # Determine explorer_dataset + if project_name: + explorer_dataset = project_name + else: + explorer_dataset = f"mcp-capture-{random.randint(1, 100)}" + + # Determine push_explorer + push_explorer = push_explorer_header == "true" + + # Create and return instance + return cls( + push_explorer=push_explorer, + explorer_dataset=explorer_dataset, + invariant_api_key=invariant_api_key, + ) + + @classmethod + def from_cli_args(cls, cli_args: list) -> "McpAttributes": + """ + Create an instance from command line arguments. + + Args: + cli_args: List of command line arguments + + Returns: + McpAttributes: An instance with values extracted from CLI arguments + """ + parser = argparse.ArgumentParser(description="MCP Gateway") + parser.add_argument( + "--project-name", + help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.", + type=str, + default=f"mcp-capture-{random.randint(1, 100)}", + ) + parser.add_argument( + "--push-explorer", + help="Enable pushing traces to Invariant Explorer", + action="store_true", + ) + parser.add_argument( + "--verbose", + help="Enable verbose logging", + action="store_true", + ) + parser.add_argument( + "--failure-response-format", + help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", + type=str, + default="error", + ) + + config, extra_args = parser.parse_known_args(cli_args) + + metadata: dict[str, Any] = {} + for arg in extra_args: + assert "=" in arg, f"Invalid extra metadata argument: {arg}" + key, value = arg.split("=") + assert key.startswith( + "--metadata-" + ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" + key = key[len("--metadata-") :] + metadata[key] = value + + return cls( + push_explorer=config.push_explorer, + explorer_dataset=config.project_name, + verbose=config.verbose, + metadata=metadata, + ) + class McpSession(BaseModel): """ @@ -40,9 +139,9 @@ class McpSession(BaseModel): session_id: str messages: list[dict[str, Any]] = Field(default_factory=list) - attributes: Optional["McpAttributes"] = None + attributes: McpAttributes | None = None id_to_method_mapping: dict[int, str] = Field(default_factory=dict) - trace_id: Optional[str] = None + trace_id: str | None = None last_trace_length: int = 0 annotations: list[dict[str, Any]] = Field(default_factory=list) guardrails: GuardrailRuleSet = Field( @@ -260,106 +359,6 @@ class McpSession(BaseModel): return messages -class McpAttributes(BaseModel): - """ - A Pydantic model to represent MCP attributes. - This can be initialized using HTTP headers for SSE and Streamable transports. - This can also be initialized using CLI arguments for the Stdio transport. - """ - - push_explorer: bool - explorer_dataset: str - invariant_api_key: Optional[str] = None - verbose: Optional[bool] = False - metadata: dict[str, Any] = Field(default_factory=dict) - - @classmethod - def from_request_headers(cls, headers: Headers) -> "McpAttributes": - """ - Create an instance from FastAPI request headers. - - Args: - headers: FastAPI Request headers - - Returns: - McpAttributes: An instance with values extracted from headers - """ - # Extract and process header values - project_name = headers.get("INVARIANT-PROJECT-NAME") - push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower() - invariant_api_key = headers.get("INVARIANT-API-KEY") - - # Determine explorer_dataset - if project_name: - explorer_dataset = project_name - else: - explorer_dataset = f"mcp-capture-{random.randint(1, 100)}" - - # Determine push_explorer - push_explorer = push_explorer_header == "true" - - # Create and return instance - return cls( - push_explorer=push_explorer, - explorer_dataset=explorer_dataset, - invariant_api_key=invariant_api_key, - ) - - @classmethod - def from_cli_args(cls, cli_args: list) -> "McpAttributes": - """ - Create an instance from command line arguments. - - Args: - cli_args: List of command line arguments - - Returns: - McpAttributes: An instance with values extracted from CLI arguments - """ - parser = argparse.ArgumentParser(description="MCP Gateway") - parser.add_argument( - "--project-name", - help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.", - type=str, - default=f"mcp-capture-{random.randint(1, 100)}", - ) - parser.add_argument( - "--push-explorer", - help="Enable pushing traces to Invariant Explorer", - action="store_true", - ) - parser.add_argument( - "--verbose", - help="Enable verbose logging", - action="store_true", - ) - parser.add_argument( - "--failure-response-format", - help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)", - type=str, - default="error", - ) - - config, extra_args = parser.parse_known_args(cli_args) - - metadata: dict[str, Any] = {} - for arg in extra_args: - assert "=" in arg, f"Invalid extra metadata argument: {arg}" - key, value = arg.split("=") - assert key.startswith( - "--metadata-" - ), f"Invalid extra metadata argument: {arg}, must start with --metadata-" - key = key[len("--metadata-") :] - metadata[key] = value - - return cls( - push_explorer=config.push_explorer, - explorer_dataset=config.project_name, - verbose=config.verbose, - metadata=metadata, - ) - - class McpSessionsManager: """ A class to manage MCP sessions and their messages. diff --git a/gateway/mcp/mcp_transport_base.py b/gateway/mcp/mcp_transport_base.py index a0118c3..a767e6d 100644 --- a/gateway/mcp/mcp_transport_base.py +++ b/gateway/mcp/mcp_transport_base.py @@ -8,7 +8,7 @@ import json import re import uuid from abc import ABC, abstractmethod -from typing import Any, Tuple +from typing import Any from fastapi import Request, HTTPException from gateway.common.guardrails import GuardrailAction @@ -43,12 +43,12 @@ class McpTransportBase(ABC): async def process_outgoing_request( self, session_id: str, request_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """ Template method for processing outgoing requests to MCP server. Returns: - Tuple[processed_request_data, is_blocked] + tuple[processed_request_data, is_blocked] """ # Update session with request information session = self.session_store.get_session(session_id) @@ -65,12 +65,12 @@ class McpTransportBase(ABC): async def process_incoming_response( self, session_id: str, response_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """ Template method for processing incoming responses from MCP server. Returns: - Tuple[processed_response, is_blocked] + tuple[processed_response, is_blocked] """ # Update session with server information session = self.session_store.get_session(session_id) @@ -99,7 +99,7 @@ class McpTransportBase(ABC): async def _intercept_outgoing_request( self, session_id: str, request_data: dict[str, Any] - ) -> Tuple[dict[str, Any], bool]: + ) -> tuple[dict[str, Any], bool]: """Common request interception logic for guardrails.""" method = request_data.get(MCP_METHOD) @@ -209,7 +209,7 @@ class McpTransportBase(ABC): @staticmethod async def hook_tool_call( session_id: str, session_store: McpSessionsManager, request_body: dict - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Hook to process the request JSON before sending it to the MCP server. @@ -219,7 +219,7 @@ class McpTransportBase(ABC): request_body (dict): The request JSON to be processed. Returns: - Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean + tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean indicating whether the request was blocked. If the request is blocked, the dict will contain an error message else it will contain the original request. """ @@ -270,7 +270,7 @@ class McpTransportBase(ABC): session_store: McpSessionsManager, response_body: dict, is_tools_list=False, - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Hook to process the response JSON after receiving it from the MCP server. @@ -280,7 +280,7 @@ class McpTransportBase(ABC): response_body (dict): The response JSON to be processed. is_tools_list (bool): Flag to indicate if the response is from a tools/list call. Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON + tuple[dict, bool]: A tuple containing the processed response JSON and a boolean indicating whether the response was blocked. If the response is blocked, the dict will contain an error message else it will contain the original response. @@ -351,7 +351,7 @@ class McpTransportBase(ABC): @staticmethod async def intercept_response( session_id: str, session_store: McpSessionsManager, response_body: dict - ) -> Tuple[dict, bool]: + ) -> tuple[dict, bool]: """ Intercept the response and check for guardrails. This function is used to intercept responses and check for guardrails. @@ -365,7 +365,7 @@ class McpTransportBase(ABC): response_body (dict): The response JSON to be processed. Returns: - Tuple[dict, bool]: A tuple containing the processed response JSON + tuple[dict, bool]: A tuple containing the processed response JSON and a boolean indicating whether the response was blocked. """ session = session_store.get_session(session_id) diff --git a/gateway/mcp/sse.py b/gateway/mcp/sse.py index 6c6b971..3f3d0c3 100644 --- a/gateway/mcp/sse.py +++ b/gateway/mcp/sse.py @@ -3,7 +3,7 @@ import asyncio import json import re -from typing import Any, AsyncGenerator, Optional, Tuple +from typing import Any, AsyncGenerator import httpx from httpx_sse import aconnect_sse, ServerSentEvent @@ -85,8 +85,8 @@ class SseTransport(McpTransportBase): **kwargs, ) -> str: """Initialize or get existing SSE session.""" - session_id: Optional[str] = kwargs.get("session_id", None) - session_attributes: Optional[McpAttributes] = kwargs.get( + session_id: str | None = kwargs.get("session_id", None) + session_attributes: McpAttributes | None = kwargs.get( "session_attributes", None ) if session_id and self.session_store.session_exists(session_id): @@ -303,7 +303,7 @@ class SseTransport(McpTransportBase): async def _handle_endpoint_event( self, sse: ServerSentEvent, sse_header_attributes: McpAttributes - ) -> Tuple[bytes, str]: + ) -> tuple[bytes, str]: """Handle endpoint event and initialize session if needed.""" match = re.search(r"session_id=([^&\s]+)", sse.data) session_id = match.group(1) if match else None diff --git a/gateway/mcp/stdio.py b/gateway/mcp/stdio.py index a44fc27..c58b61c 100644 --- a/gateway/mcp/stdio.py +++ b/gateway/mcp/stdio.py @@ -7,7 +7,6 @@ import platform import select import subprocess import sys -from typing import Optional, Tuple from gateway.mcp.constants import UTF_8 from gateway.mcp.log import mcp_log, MCP_LOG_FILE @@ -210,7 +209,7 @@ class StdioTransport(McpTransportBase): async def _wait_for_stdin_input( self, loop: asyncio.AbstractEventLoop, stdin_fd: int - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: """Platform-specific implementation to wait for and read input from stdin.""" if platform.system() == "Windows": await asyncio.sleep(0.01) @@ -261,7 +260,7 @@ async def create_stdio_transport_and_execute( ) -def split_args(args: list[str] = None) -> tuple[list[str], list[str]]: +def split_args(args: list[str] | None = None) -> tuple[list[str], list[str]]: """ Splits CLI arguments into two parts: 1. Arguments intended for the MCP gateway (everything before `--exec`) diff --git a/gateway/mcp/streamable.py b/gateway/mcp/streamable.py index 930abbd..15e67e8 100644 --- a/gateway/mcp/streamable.py +++ b/gateway/mcp/streamable.py @@ -1,7 +1,7 @@ """Gateway service to forward requests to the MCP Streamable HTTP servers""" import json -from typing import Any, Optional +from typing import Any import httpx from httpx_sse import aconnect_sse @@ -90,8 +90,8 @@ class StreamableTransport(McpTransportBase): **kwargs, ) -> str: """Initialize streamable HTTP session.""" - session_id: Optional[str] = kwargs.get("session_id", None) - session_attributes: Optional[McpAttributes] = kwargs.get( + session_id: str | None = kwargs.get("session_id", None) + session_attributes: McpAttributes | None = kwargs.get( "session_attributes", None ) is_initialization_request: bool = kwargs.get("is_initialization_request", False) @@ -240,7 +240,7 @@ class StreamableTransport(McpTransportBase): async def _process_non_init_request( self, session_id: str, request_body: dict[str, Any] - ) -> Optional[Response]: + ) -> Response | None: """Process non-initialization requests for guardrails.""" processed_request, is_blocked = await self.process_outgoing_request( session_id, request_body diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index e455b11..0fcbb1f 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -69,7 +69,7 @@ def validate_headers(x_api_key: str = Header(None)): ) async def anthropic_v1_messages_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ): @@ -167,7 +167,7 @@ async def get_guardrails_check_result( async def push_to_explorer( context: RequestContext, merged_response: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} @@ -215,9 +215,9 @@ class InstrumentedAnthropicResponse(InstrumentedResponse): self.anthropic_request: httpx.Request = anthropic_request # response data - self.response: Optional[httpx.Response] = None - self.response_string: Optional[str] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_string: str | None = None + self.response_json: dict[str, Any] | None = None # guardrailing response (if any) self.guardrails_execution_result = {} @@ -553,7 +553,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse): """Process the buffer and extract complete SSE events. Returns: - Tuple[List[str], str]: A tuple containing a list of + tuple[list[str], str]: A tuple containing a list of complete events and the remaining buffer with incomplete events. """ # Split on double newlines which separate SSE events diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 14a3a63..d386929 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response @@ -49,7 +49,7 @@ async def gemini_generate_content_gateway( api_version: str, model: str, endpoint: str, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer alt: str = Query( None, title="Response Format", description="Set to 'sse' for streaming" ), @@ -147,7 +147,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): } # guardrailing execution result (if any) - self.guardrails_execution_result: Optional[dict[str, Any]] = None + self.guardrails_execution_result: dict[str, Any] | None = None def make_refusal( self, @@ -415,7 +415,7 @@ async def get_guardrails_check_result( async def push_to_explorer( context: RequestContext, response_json: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the full trace to the Invariant Explorer""" guardrails_execution_result = guardrails_execution_result or {} @@ -464,11 +464,11 @@ class InstrumentedGeminiResponse(InstrumentedResponse): self.gemini_request: httpx.Request = gemini_request # response data - self.response: Optional[httpx.Response] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_json: dict[str, Any] | None = None # guardrails execution result (if any) - self.guardrails_execution_result: Optional[dict[str, Any]] = None + self.guardrails_execution_result: dict[str, Any] | None = None async def on_start(self): """ diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 2d8501e..1c8e1d9 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any, Optional +from typing import Any import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response @@ -62,14 +62,14 @@ def make_cors_response(request: Request, allow_methods: str) -> Response: @gateway.options("/{dataset_name}/openai/chat/completions") @gateway.options("/openai/chat/completions") -async def openai_chat_completions_options(request: Request, dataset_name: str = None): +async def openai_chat_completions_options(request: Request): """Enables CORS for the OpenAI chat completions endpoint""" return make_cors_response(request, allow_methods="POST") @gateway.options("/{dataset_name}/openai/models") @gateway.options("/openai/models") -async def openai_models_options(request: Request, dataset_name: str = None): +async def openai_models_options(request: Request): """Enables CORS for the OpenAI models endpoint""" return make_cors_response(request, allow_methods="GET") @@ -78,7 +78,7 @@ async def openai_models_options(request: Request, dataset_name: str = None): @gateway.get("/openai/models") async def openai_models_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer ): """Proxy request to OpenAI /models endpoint""" headers = { @@ -112,7 +112,7 @@ async def openai_models_gateway( ) async def openai_chat_completions_gateway( request: Request, - dataset_name: str = None, # This is None if the client doesn't want to push to Explorer + dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header), ) -> Response: @@ -182,7 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): self.open_ai_request: httpx.Request = open_ai_request # guardrailing output (if any) - self.guardrails_execution_result: Optional[dict] = None + self.guardrails_execution_result: dict | None = None # merged_response will be updated with the data from the chunks in the stream # At the end of the stream, this will be sent to the explorer @@ -486,7 +486,7 @@ def create_metadata( async def push_to_explorer( context: RequestContext, merged_response: dict[str, Any], - guardrails_execution_result: Optional[dict] = None, + guardrails_execution_result: dict | None = None, ) -> None: """Pushes the merged response to the Invariant Explorer""" # Only push the trace to explorer if the message is an end turn message @@ -572,11 +572,11 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): self.open_ai_request: httpx.Request = open_ai_request # request outputs - self.response: Optional[httpx.Response] = None - self.response_json: Optional[dict[str, Any]] = None + self.response: httpx.Response | None = None + self.response_json: dict[str, Any] | None = None # guardrailing output (if any) - self.guardrails_execution_result: Optional[dict] = None + self.guardrails_execution_result: dict | None = None async def on_start(self): """ diff --git a/tests/integration/anthropic/test_anthropic_with_tool_call.py b/tests/integration/anthropic/test_anthropic_with_tool_call.py index 82cdc82..365af35 100644 --- a/tests/integration/anthropic/test_anthropic_with_tool_call.py +++ b/tests/integration/anthropic/test_anthropic_with_tool_call.py @@ -7,7 +7,6 @@ import sys import time import uuid from pathlib import Path -from typing import Dict, List # Add integration folder (parent) to sys.path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -49,7 +48,7 @@ class WeatherAgent: }, } - def get_response(self, messages: List[Dict]) -> List[Dict]: + def get_response(self, messages: list[dict]) -> list[dict]: """ Get the response from the agent for a given user query for weather. """ @@ -83,7 +82,7 @@ class WeatherAgent: else: return response_list - def get_streaming_response(self, messages: List[Dict]) -> List[Dict]: + def get_streaming_response(self, messages: list[dict]) -> list[dict]: """Get streaming response from the agent for a given user query for weather.""" response_list = [] diff --git a/tests/integration/resources/mcp/sse/client/main.py b/tests/integration/resources/mcp/sse/client/main.py index d74b592..b804b15 100644 --- a/tests/integration/resources/mcp/sse/client/main.py +++ b/tests/integration/resources/mcp/sse/client/main.py @@ -11,7 +11,7 @@ async def run( gateway_url: str, tool_name: str, tool_args: dict[str, Any], - headers: dict[str, str] = None, + headers: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Run the MCP client with the given parameters. diff --git a/tests/integration/resources/mcp/stdio/client/main.py b/tests/integration/resources/mcp/stdio/client/main.py index 3aeaaf3..5afad84 100644 --- a/tests/integration/resources/mcp/stdio/client/main.py +++ b/tests/integration/resources/mcp/stdio/client/main.py @@ -3,7 +3,7 @@ import os from datetime import timedelta -from typing import Any, Optional +from typing import Any from mcp import ClientSession, StdioServerParameters, types from mcp.client.stdio import stdio_client @@ -14,7 +14,7 @@ def _get_server_params( project_name: str, server_script_path: str, push_to_explorer: bool, - metadata_keys: Optional[dict[str, str]] = None, + metadata_keys: dict[str, str] | None = None, ) -> StdioServerParameters: args = [ "--from", @@ -59,7 +59,7 @@ async def run( push_to_explorer: bool, tool_name: str, tool_args: dict[str, Any], - metadata_keys: Optional[dict[str, str]] = None, + metadata_keys: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Main function to setup the MCP client and server. diff --git a/tests/integration/resources/mcp/streamable/client/main.py b/tests/integration/resources/mcp/streamable/client/main.py index da7fc23..6d972c7 100644 --- a/tests/integration/resources/mcp/streamable/client/main.py +++ b/tests/integration/resources/mcp/streamable/client/main.py @@ -12,7 +12,7 @@ async def run( gateway_url: str, tool_name: str, tool_args: dict[str, Any], - headers: dict[str, str] = None, + headers: dict[str, str] | None = None, ) -> types.CallToolResult | types.ListToolsResult: """ Run the MCP client with the given parameters. diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3c6d3bd..438f6cc 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -2,7 +2,7 @@ import os import uuid -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal from httpx import Client from openai import OpenAI @@ -62,8 +62,8 @@ def get_gemini_client( async def create_dataset( explorer_api_url: str, invariant_authorization: str, - dataset_name: Optional[str] = None, -) -> Dict[str, Any]: + dataset_name: str | None = None, +) -> dict[str, Any]: """Create a dataset in the Explorer API.""" client = Client(base_url=explorer_api_url) response = client.post( @@ -85,7 +85,7 @@ async def add_guardrail_to_dataset( policy: str, action: Literal["block", "log"], invariant_authorization: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Add a guardrail to a dataset.""" client = Client(base_url=explorer_api_url) response = client.post(