mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-01 08:45:32 +02:00
Merge remote-tracking branch 'origin/main' into feature-message-timestamps
This commit is contained in:
+2
-4
@@ -7,8 +7,6 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from gateway.mcp import stdio as mcp_stdio
|
||||
from gateway.mcp.log import mcp_log
|
||||
|
||||
@@ -64,7 +62,7 @@ def ensure_network_exists(network_name: str = "invariant-explorer-web") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def setup_guardrails(guardrails_file_path: Optional[str] = None) -> bool:
|
||||
def setup_guardrails(guardrails_file_path: str | None = None) -> bool:
|
||||
"""Configure guardrails if specified."""
|
||||
if not guardrails_file_path:
|
||||
return True
|
||||
@@ -105,7 +103,7 @@ def build():
|
||||
return False
|
||||
|
||||
|
||||
def up(guardrails_file_path: Optional[str] = None):
|
||||
def up(guardrails_file_path: str | None = None):
|
||||
"""Set up the local server for the Invariant Gateway."""
|
||||
# Ensure network exists
|
||||
if not ensure_network_exists():
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Common Authorization functions used in the gateway."""
|
||||
|
||||
from typing import Tuple, Optional
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
|
||||
@@ -10,7 +9,7 @@ API_KEYS_SEPARATOR = ";invariant-auth="
|
||||
|
||||
def extract_guardrail_service_authorization_from_headers(
|
||||
request: Request,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Extracts the optional Invariant-Guardrails-Authorization authorization header from the request.
|
||||
|
||||
@@ -22,10 +21,10 @@ def extract_guardrail_service_authorization_from_headers(
|
||||
|
||||
def extract_authorization_from_headers(
|
||||
request: Request,
|
||||
dataset_name: Optional[str] = None,
|
||||
llm_provider_api_key_header: Optional[str] = None,
|
||||
llm_provider_fallback_api_key_headers: Optional[list[str]] = None,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
dataset_name: str | None = None,
|
||||
llm_provider_api_key_header: str | None = None,
|
||||
llm_provider_fallback_api_key_headers: list[str] | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Extracts the Invariant authorization and LLM Provider API key from the request headers.
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import fastapi
|
||||
from httpx import HTTPStatusError
|
||||
@@ -11,7 +10,7 @@ from httpx import HTTPStatusError
|
||||
from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet
|
||||
|
||||
|
||||
def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]:
|
||||
def extract_policy_from_headers(request: fastapi.Request | None) -> str | None:
|
||||
"""
|
||||
Extracts the guardrailing policy from the request headers if present.
|
||||
|
||||
@@ -79,7 +78,7 @@ class GatewayConfigManager:
|
||||
_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, request: fastapi.Request = None) -> GatewayConfig:
|
||||
def get_config(cls) -> GatewayConfig:
|
||||
"""Initializes and returns the gateway configuration using double-checked locking."""
|
||||
local_config = cls._config_instance
|
||||
|
||||
@@ -95,7 +94,7 @@ class GatewayConfigManager:
|
||||
|
||||
async def extract_guardrails_from_header(
|
||||
request: fastapi.Request,
|
||||
) -> Optional[GuardrailRuleSet]:
|
||||
) -> GuardrailRuleSet | None:
|
||||
"""
|
||||
Extracts Invariant-Guardrails from the request header if provided, and returns a corresponding
|
||||
GuardrailRuleSet. If no guardrails are provided, returns None.
|
||||
@@ -115,3 +114,4 @@ async def extract_guardrails_from_header(
|
||||
blocking_guardrails=guardrails,
|
||||
logging_guardrails=[],
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -16,3 +16,6 @@ IGNORED_HEADERS = [
|
||||
|
||||
CLIENT_TIMEOUT = 60.0
|
||||
|
||||
CONTENT_TYPE_HEADER = "content-type"
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_EVENT_STREAM = "text/event-stream"
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""Common guardrails data class."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class GuardrailAction(str, Enum):
|
||||
"""Enum representing the action to be taken for guardrail rules."""
|
||||
@@ -27,5 +24,5 @@ class Guardrail:
|
||||
class GuardrailRuleSet:
|
||||
"""Grouped guardrail rules separated by their action."""
|
||||
|
||||
blocking_guardrails: List[Guardrail]
|
||||
logging_guardrails: List[Guardrail]
|
||||
blocking_guardrails: list[Guardrail]
|
||||
logging_guardrails: list[Guardrail]
|
||||
|
||||
@@ -1,33 +1,33 @@
|
||||
"""Common Request context data class."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
import fastapi
|
||||
|
||||
from gateway.common.config_manager import GatewayConfig
|
||||
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
from gateway.common.authorization import (
|
||||
extract_guardrail_service_authorization_from_headers,
|
||||
)
|
||||
from gateway.common.config_manager import GatewayConfig
|
||||
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestContext:
|
||||
"""Structured context for a request. Must be created via `RequestContext.create()`."""
|
||||
|
||||
request_json: Dict[str, Any]
|
||||
dataset_name: Optional[str] = None
|
||||
request_json: dict[str, Any]
|
||||
dataset_name: str | None = None
|
||||
# authorization to use for invariant service like explorer
|
||||
invariant_authorization: Optional[str] = None
|
||||
invariant_authorization: str | None = None
|
||||
# authorization to use for invariant guardrailing specifically
|
||||
guardrail_authorization: Optional[str] = None
|
||||
guardrail_authorization: str | None = None
|
||||
# the set of guardrails to enforce for this request
|
||||
guardrails: Optional[GuardrailRuleSet] = None
|
||||
config: Dict[str, Any] = None
|
||||
|
||||
guardrails: GuardrailRuleSet | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
# extra parameters available as input.<key> during guardrail evaluation
|
||||
guardrails_parameters: Optional[Dict[str, Any]] = None
|
||||
guardrails_parameters: dict[str, Any] | None = None
|
||||
|
||||
_created_via_factory: bool = field(
|
||||
default=False, init=True, repr=False, compare=False
|
||||
@@ -42,13 +42,13 @@ class RequestContext:
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
request_json: Dict[str, Any],
|
||||
dataset_name: Optional[str] = None,
|
||||
invariant_authorization: Optional[str] = None,
|
||||
guardrails: Optional[GuardrailRuleSet] = None,
|
||||
config: Optional[GatewayConfig] = None,
|
||||
request: fastapi.Request = None,
|
||||
guardrails_parameters: Optional[Dict[str, Any]] = None,
|
||||
request_json: dict[str, Any],
|
||||
dataset_name: str | None = None,
|
||||
invariant_authorization: str | None = None,
|
||||
guardrails: GuardrailRuleSet | None = None,
|
||||
config: GatewayConfig | None = None,
|
||||
request: fastapi.Request | None = None,
|
||||
guardrails_parameters: dict[str, Any] | None = None,
|
||||
) -> "RequestContext":
|
||||
"""Creates a new RequestContext instance, applying default guardrails if needed."""
|
||||
|
||||
@@ -100,10 +100,10 @@ class RequestContext:
|
||||
guardrails=guardrails,
|
||||
config=context_config,
|
||||
_created_via_factory=True,
|
||||
guardrails_parameters=guardrails_parameters
|
||||
guardrails_parameters=guardrails_parameters,
|
||||
)
|
||||
|
||||
def get_guardrailing_authorization(self) -> Optional[str]:
|
||||
def get_guardrailing_authorization(self) -> str | None:
|
||||
"""
|
||||
Returns the authorization to use for the guardrailing service.
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from typing import Any, Dict, List
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from gateway.common.constants import DEFAULT_API_URL
|
||||
@@ -12,12 +13,10 @@ from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
||||
from invariant_sdk.types.annotations import AnnotationCreate
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def create_annotations_from_guardrails_errors(
|
||||
guardrails_errors: List[dict],
|
||||
) -> List[AnnotationCreate]:
|
||||
guardrails_errors: list[dict],
|
||||
) -> list[AnnotationCreate]:
|
||||
"""Create Explorer annotations from the guardrails errors."""
|
||||
annotations = []
|
||||
|
||||
@@ -68,7 +67,7 @@ def create_annotations_from_guardrails_errors(
|
||||
return remove_duplicates(annotations)
|
||||
|
||||
|
||||
def remove_duplicates(annotations: List[AnnotationCreate]) -> List[AnnotationCreate]:
|
||||
def remove_duplicates(annotations: list[AnnotationCreate]) -> list[AnnotationCreate]:
|
||||
"""
|
||||
Remove duplicate annotations based on content, address, and extra_metadata.
|
||||
|
||||
@@ -99,18 +98,18 @@ def get_explorer_api_url() -> str:
|
||||
|
||||
|
||||
async def push_trace(
|
||||
messages: List[List[Dict[str, Any]]],
|
||||
messages: list[list[dict[str, Any]]],
|
||||
dataset_name: str,
|
||||
invariant_authorization: str,
|
||||
annotations: List[List[AnnotationCreate]] = None,
|
||||
metadata: List[Dict[str, Any]] = None,
|
||||
annotations: list[list[AnnotationCreate]] | None = None,
|
||||
metadata: list[dict[str, Any]] | None = None,
|
||||
) -> PushTracesResponse:
|
||||
"""Pushes traces to the dataset on the Invariant Explorer.
|
||||
|
||||
If a dataset with the given name does not exist, it will be created.
|
||||
|
||||
Args:
|
||||
messages (List[List[Dict[str, Any]]]): List of messages to push.
|
||||
messages (list[list[dict[str, Any]]]): List of messages to push.
|
||||
dataset_name (str): Name of the dataset.
|
||||
invariant_authorization (str): Value of the
|
||||
invariant-authorization header.
|
||||
@@ -135,7 +134,7 @@ async def push_trace(
|
||||
)
|
||||
try:
|
||||
return await client.push_trace(request)
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Failed to push trace: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -339,22 +339,22 @@ class InstrumentedResponse(InstrumentedStreamingResponse):
|
||||
|
||||
|
||||
async def check_guardrails(
|
||||
messages: List[Dict[str, Any]],
|
||||
guardrails: List[Guardrail],
|
||||
messages: list[dict[str, Any]],
|
||||
guardrails: list[Guardrail],
|
||||
context: RequestContext,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Checks guardrails on the list of messages.
|
||||
This calls the batch check API of the Guardrails service.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, Any]]): List of messages to verify the guardrails against.
|
||||
guardrails (List[Guardrail]): The guardrails to check against.
|
||||
messages (list[dict[str, Any]]): List of messages to verify the guardrails against.
|
||||
guardrails (list[Guardrail]): The guardrails to check against.
|
||||
context (RequestContext): Request context providing the guardrailing
|
||||
authorization and the guardrail parameters.
|
||||
|
||||
Returns:
|
||||
Dict: Response containing guardrail check results.
|
||||
dict: Response containing guardrail check results.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = os.getenv("GUARDRAILS_API_URL", DEFAULT_API_URL).rstrip("/")
|
||||
@@ -364,11 +364,11 @@ async def check_guardrails(
|
||||
json={
|
||||
"messages": messages,
|
||||
"policies": [g.content for g in guardrails],
|
||||
"parameters": context.guardrails_parameters or {}
|
||||
"parameters": context.guardrails_parameters or {},
|
||||
},
|
||||
headers={
|
||||
"Authorization": context.get_guardrailing_authorization(),
|
||||
"Accept": "application/json"
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=5,
|
||||
)
|
||||
@@ -376,11 +376,14 @@ async def check_guardrails(
|
||||
if result.status_code == 401:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format(
|
||||
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER
|
||||
detail=(
|
||||
"The provided Invariant API key is not valid for guardrail checking. "
|
||||
"Please ensure you are using the correct API key or pass an "
|
||||
"alternative API key for guardrail checking specifically via the "
|
||||
f"'{INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER}' header."
|
||||
),
|
||||
)
|
||||
raise Exception(
|
||||
raise Exception( # pylint: disable=broad-exception-raised
|
||||
f"Guardrails check failed: {result.status_code} - {result.text}"
|
||||
)
|
||||
guardrails_result = result.json()
|
||||
@@ -412,7 +415,7 @@ async def check_guardrails(
|
||||
return aggregated_errors
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Failed to verify guardrails: {e}")
|
||||
# make sure runtime errors are also visible in e.g. Explorer
|
||||
return {
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from builtins import print as builtins_print
|
||||
|
||||
os.makedirs(os.path.join(os.path.expanduser("~"), ".invariant"), exist_ok=True)
|
||||
|
||||
+102
-107
@@ -7,7 +7,7 @@ import getpass
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from invariant_sdk.async_client import AsyncClient
|
||||
from invariant_sdk.types.append_messages import AppendMessagesRequest
|
||||
@@ -23,7 +23,6 @@ from gateway.integrations.explorer import (
|
||||
fetch_guardrails_from_explorer,
|
||||
)
|
||||
from gateway.integrations.guardrails import check_guardrails
|
||||
from gateway.mcp.constants import INVARIANT_SESSION_ID_PREFIX
|
||||
|
||||
|
||||
def user_and_host() -> str:
|
||||
@@ -33,6 +32,105 @@ def user_and_host() -> str:
|
||||
|
||||
return f"{username}@{hostname}"
|
||||
|
||||
class McpAttributes(BaseModel):
|
||||
"""
|
||||
A Pydantic model to represent MCP attributes.
|
||||
This can be initialized using HTTP headers for SSE and Streamable transports.
|
||||
This can also be initialized using CLI arguments for the Stdio transport.
|
||||
"""
|
||||
|
||||
push_explorer: bool
|
||||
explorer_dataset: str
|
||||
invariant_api_key: str | None = None
|
||||
verbose: bool | None = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
|
||||
"""
|
||||
Create an instance from FastAPI request headers.
|
||||
|
||||
Args:
|
||||
headers: FastAPI Request headers
|
||||
|
||||
Returns:
|
||||
McpAttributes: An instance with values extracted from headers
|
||||
"""
|
||||
# Extract and process header values
|
||||
project_name = headers.get("INVARIANT-PROJECT-NAME")
|
||||
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
|
||||
invariant_api_key = headers.get("INVARIANT-API-KEY")
|
||||
|
||||
# Determine explorer_dataset
|
||||
if project_name:
|
||||
explorer_dataset = project_name
|
||||
else:
|
||||
explorer_dataset = f"mcp-capture-{random.randint(1, 100)}"
|
||||
|
||||
# Determine push_explorer
|
||||
push_explorer = push_explorer_header == "true"
|
||||
|
||||
# Create and return instance
|
||||
return cls(
|
||||
push_explorer=push_explorer,
|
||||
explorer_dataset=explorer_dataset,
|
||||
invariant_api_key=invariant_api_key,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
|
||||
"""
|
||||
Create an instance from command line arguments.
|
||||
|
||||
Args:
|
||||
cli_args: List of command line arguments
|
||||
|
||||
Returns:
|
||||
McpAttributes: An instance with values extracted from CLI arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="MCP Gateway")
|
||||
parser.add_argument(
|
||||
"--project-name",
|
||||
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
|
||||
type=str,
|
||||
default=f"mcp-capture-{random.randint(1, 100)}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-explorer",
|
||||
help="Enable pushing traces to Invariant Explorer",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
help="Enable verbose logging",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--failure-response-format",
|
||||
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
|
||||
type=str,
|
||||
default="error",
|
||||
)
|
||||
|
||||
config, extra_args = parser.parse_known_args(cli_args)
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
for arg in extra_args:
|
||||
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
|
||||
key, value = arg.split("=")
|
||||
assert key.startswith(
|
||||
"--metadata-"
|
||||
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
|
||||
key = key[len("--metadata-") :]
|
||||
metadata[key] = value
|
||||
|
||||
return cls(
|
||||
push_explorer=config.push_explorer,
|
||||
explorer_dataset=config.project_name,
|
||||
verbose=config.verbose,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class McpSession(BaseModel):
|
||||
"""
|
||||
@@ -41,9 +139,9 @@ class McpSession(BaseModel):
|
||||
|
||||
session_id: str
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
attributes: Optional["McpAttributes"] = None
|
||||
attributes: McpAttributes | None = None
|
||||
id_to_method_mapping: dict[int, str] = Field(default_factory=dict)
|
||||
trace_id: Optional[str] = None
|
||||
trace_id: str | None = None
|
||||
last_trace_length: int = 0
|
||||
annotations: list[dict[str, Any]] = Field(default_factory=list)
|
||||
guardrails: GuardrailRuleSet = Field(
|
||||
@@ -110,9 +208,6 @@ class McpSession(BaseModel):
|
||||
"system_user": user_and_host(),
|
||||
**(self.attributes.metadata or {}),
|
||||
}
|
||||
metadata["is_stateless_http_server"] = self.session_id.startswith(
|
||||
INVARIANT_SESSION_ID_PREFIX
|
||||
)
|
||||
return metadata
|
||||
|
||||
async def get_guardrails_check_result(
|
||||
@@ -264,106 +359,6 @@ class McpSession(BaseModel):
|
||||
return messages
|
||||
|
||||
|
||||
class McpAttributes(BaseModel):
|
||||
"""
|
||||
A Pydantic model to represent MCP attributes.
|
||||
This can be initialized using HTTP headers for SSE and Streamable transports.
|
||||
This can also be initialized using CLI arguments for the Stdio transport.
|
||||
"""
|
||||
|
||||
push_explorer: bool
|
||||
explorer_dataset: str
|
||||
invariant_api_key: Optional[str] = None
|
||||
verbose: Optional[bool] = False
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_request_headers(cls, headers: Headers) -> "McpAttributes":
|
||||
"""
|
||||
Create an instance from FastAPI request headers.
|
||||
|
||||
Args:
|
||||
headers: FastAPI Request headers
|
||||
|
||||
Returns:
|
||||
McpAttributes: An instance with values extracted from headers
|
||||
"""
|
||||
# Extract and process header values
|
||||
project_name = headers.get("INVARIANT-PROJECT-NAME")
|
||||
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
|
||||
invariant_api_key = headers.get("INVARIANT-API-KEY")
|
||||
|
||||
# Determine explorer_dataset
|
||||
if project_name:
|
||||
explorer_dataset = project_name
|
||||
else:
|
||||
explorer_dataset = f"mcp-capture-{random.randint(1, 100)}"
|
||||
|
||||
# Determine push_explorer
|
||||
push_explorer = push_explorer_header == "true"
|
||||
|
||||
# Create and return instance
|
||||
return cls(
|
||||
push_explorer=push_explorer,
|
||||
explorer_dataset=explorer_dataset,
|
||||
invariant_api_key=invariant_api_key,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, cli_args: list) -> "McpAttributes":
|
||||
"""
|
||||
Create an instance from command line arguments.
|
||||
|
||||
Args:
|
||||
cli_args: List of command line arguments
|
||||
|
||||
Returns:
|
||||
McpAttributes: An instance with values extracted from CLI arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="MCP Gateway")
|
||||
parser.add_argument(
|
||||
"--project-name",
|
||||
help="Name of the Project from Invariant Explorer where we want to push the MCP traces. The guardrails are pulled from this project.",
|
||||
type=str,
|
||||
default=f"mcp-capture-{random.randint(1, 100)}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-explorer",
|
||||
help="Enable pushing traces to Invariant Explorer",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
help="Enable verbose logging",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--failure-response-format",
|
||||
help="The response format to use to communicate guardrail failures to the client (error: JSON-RPC error response; potentially invisible to the agent, content: JSON-RPC content response, visible to the agent)",
|
||||
type=str,
|
||||
default="error",
|
||||
)
|
||||
|
||||
config, extra_args = parser.parse_known_args(cli_args)
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
for arg in extra_args:
|
||||
assert "=" in arg, f"Invalid extra metadata argument: {arg}"
|
||||
key, value = arg.split("=")
|
||||
assert key.startswith(
|
||||
"--metadata-"
|
||||
), f"Invalid extra metadata argument: {arg}, must start with --metadata-"
|
||||
key = key[len("--metadata-") :]
|
||||
metadata[key] = value
|
||||
|
||||
return cls(
|
||||
push_explorer=config.push_explorer,
|
||||
explorer_dataset=config.project_name,
|
||||
verbose=config.verbose,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
class McpSessionsManager:
|
||||
"""
|
||||
A class to manage MCP sessions and their messages.
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
@@ -31,7 +31,7 @@ from gateway.mcp.log import format_errors_in_response
|
||||
from gateway.mcp.mcp_sessions_manager import McpSession, McpSessionsManager
|
||||
|
||||
|
||||
class MCPTransportBase(ABC):
|
||||
class McpTransportBase(ABC):
|
||||
"""
|
||||
Abstract base class for MCP transport strategies.
|
||||
|
||||
@@ -44,16 +44,16 @@ class MCPTransportBase(ABC):
|
||||
|
||||
async def process_outgoing_request(
|
||||
self, session_id: str, request_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
"""
|
||||
Template method for processing outgoing requests to MCP server.
|
||||
|
||||
Returns:
|
||||
Tuple[processed_request_data, is_blocked]
|
||||
tuple[processed_request_data, is_blocked]
|
||||
"""
|
||||
# Update session with request information
|
||||
session = self.session_store.get_session(session_id)
|
||||
MCPTransportBase.update_session_from_request(session, request_data)
|
||||
self.update_session_from_request(session, request_data)
|
||||
|
||||
# Refresh guardrails
|
||||
await session.load_guardrails()
|
||||
@@ -66,19 +66,19 @@ class MCPTransportBase(ABC):
|
||||
|
||||
async def process_incoming_response(
|
||||
self, session_id: str, response_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
"""
|
||||
Template method for processing incoming responses from MCP server.
|
||||
|
||||
Returns:
|
||||
Tuple[processed_response, is_blocked]
|
||||
tuple[processed_response, is_blocked]
|
||||
"""
|
||||
# Update session with server information
|
||||
session = self.session_store.get_session(session_id)
|
||||
MCPTransportBase.update_mcp_server_in_session_metadata(session, response_data)
|
||||
self.update_mcp_server_in_session_metadata(session, response_data)
|
||||
|
||||
# Intercept and apply guardrails to response
|
||||
return await MCPTransportBase.intercept_response(
|
||||
return await McpTransportBase.intercept_response(
|
||||
session_id, self.session_store, response_data
|
||||
)
|
||||
|
||||
@@ -87,20 +87,31 @@ class MCPTransportBase(ABC):
|
||||
method = request_data.get(MCP_METHOD)
|
||||
return method in [MCP_TOOL_CALL, MCP_LIST_TOOLS]
|
||||
|
||||
@staticmethod
|
||||
def _create_jsonrpc_error_response(request_body: dict, message: str) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
|
||||
async def _intercept_outgoing_request(
|
||||
self, session_id: str, request_data: dict[str, Any]
|
||||
) -> Tuple[dict[str, Any], bool]:
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
"""Common request interception logic for guardrails."""
|
||||
method = request_data.get(MCP_METHOD)
|
||||
|
||||
interception_result = request_data
|
||||
is_blocked = False
|
||||
if method == MCP_TOOL_CALL:
|
||||
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
|
||||
interception_result, is_blocked = await self.hook_tool_call(
|
||||
session_id, self.session_store, request_data
|
||||
)
|
||||
elif method == MCP_LIST_TOOLS:
|
||||
interception_result, is_blocked = await MCPTransportBase.hook_tool_call(
|
||||
interception_result, is_blocked = await self.hook_tool_call(
|
||||
session_id=session_id,
|
||||
session_store=self.session_store,
|
||||
request_body={
|
||||
@@ -123,7 +134,7 @@ class MCPTransportBase(ABC):
|
||||
"tool_call_id": f"call_{request_body.get('id')}",
|
||||
"content": request_body.get(MCP_RESULT, {}).get("content"),
|
||||
"error": request_body.get(MCP_RESULT, {}).get("error"),
|
||||
"timestamp": MCPTransportBase.generate_timestamp(),
|
||||
"timestamp": McpTransportBase.generate_timestamp(),
|
||||
}
|
||||
|
||||
return message
|
||||
@@ -142,7 +153,7 @@ class MCPTransportBase(ABC):
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [tool_call],
|
||||
"timestamp": MCPTransportBase.generate_timestamp(),
|
||||
"timestamp": McpTransportBase.generate_timestamp(),
|
||||
}
|
||||
return message
|
||||
|
||||
@@ -186,8 +197,8 @@ class MCPTransportBase(ABC):
|
||||
@staticmethod
|
||||
def update_session_from_request(session: McpSession, request_body: dict) -> None:
|
||||
"""Update the MCP client information and request id in the session."""
|
||||
MCPTransportBase.update_mcp_client_info_in_session(session, request_body)
|
||||
MCPTransportBase.update_tool_call_id_in_session(session, request_body)
|
||||
McpTransportBase.update_mcp_client_info_in_session(session, request_body)
|
||||
McpTransportBase.update_tool_call_id_in_session(session, request_body)
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_server_base_url(request: Request) -> str:
|
||||
@@ -198,7 +209,7 @@ class MCPTransportBase(ABC):
|
||||
status_code=400,
|
||||
detail=f"Missing {MCP_SERVER_BASE_URL_HEADER} header",
|
||||
)
|
||||
return MCPTransportBase.convert_localhost_to_docker_host(
|
||||
return McpTransportBase.convert_localhost_to_docker_host(
|
||||
mcp_server_base_url
|
||||
).rstrip("/")
|
||||
|
||||
@@ -233,7 +244,7 @@ class MCPTransportBase(ABC):
|
||||
@staticmethod
|
||||
async def hook_tool_call(
|
||||
session_id: str, session_store: McpSessionsManager, request_body: dict
|
||||
) -> Tuple[dict, bool]:
|
||||
) -> tuple[dict, bool]:
|
||||
"""
|
||||
Hook to process the request JSON before sending it to the MCP server.
|
||||
|
||||
@@ -243,11 +254,11 @@ class MCPTransportBase(ABC):
|
||||
request_body (dict): The request JSON to be processed.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, bool]: A tuple hook tool call response as a dict and a boolean
|
||||
tuple[dict, bool]: A tuple containing the hook tool call response as a dict and a boolean
|
||||
indicating whether the request was blocked. If the request is blocked, the
|
||||
dict will contain an error message else it will contain the original request.
|
||||
"""
|
||||
message = MCPTransportBase.generate_request_message(request_body)
|
||||
message = McpTransportBase.generate_request_message(request_body)
|
||||
|
||||
# Check for blocking guardrails
|
||||
session = session_store.get_session(session_id)
|
||||
@@ -259,7 +270,7 @@ class MCPTransportBase(ABC):
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and MCPTransportBase.check_if_new_errors(
|
||||
and McpTransportBase.check_if_new_errors(
|
||||
session_id, session_store, guardrails_result
|
||||
)
|
||||
):
|
||||
@@ -269,15 +280,10 @@ class MCPTransportBase(ABC):
|
||||
message=message,
|
||||
guardrails_result=guardrails_result,
|
||||
)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}, True
|
||||
return McpTransportBase._create_jsonrpc_error_response(
|
||||
request_body,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
|
||||
), True
|
||||
|
||||
# Push trace to the explorer
|
||||
await session_store.add_message_to_session(
|
||||
@@ -291,7 +297,7 @@ class MCPTransportBase(ABC):
|
||||
session_store: McpSessionsManager,
|
||||
response_body: dict,
|
||||
is_tools_list=False,
|
||||
) -> Tuple[dict, bool]:
|
||||
) -> tuple[dict, bool]:
|
||||
"""
|
||||
|
||||
Hook to process the response JSON after receiving it from the MCP server.
|
||||
@@ -301,7 +307,7 @@ class MCPTransportBase(ABC):
|
||||
response_body (dict): The response JSON to be processed.
|
||||
is_tools_list (bool): Flag to indicate if the response is from a tools/list call.
|
||||
Returns:
|
||||
Tuple[dict, bool]: A tuple containing the processed response JSON
|
||||
tuple[dict, bool]: A tuple containing the processed response JSON
|
||||
and a boolean indicating whether the response was blocked. If the response
|
||||
is blocked, the dict will contain an error message else it will contain the
|
||||
original response.
|
||||
@@ -309,7 +315,7 @@ class MCPTransportBase(ABC):
|
||||
is_blocked = False
|
||||
result = response_body
|
||||
|
||||
message = MCPTransportBase.generate_response_message(result)
|
||||
message = McpTransportBase.generate_response_message(result)
|
||||
|
||||
session = session_store.get_session(session_id)
|
||||
guardrails_result = await session.get_guardrails_check_result(
|
||||
@@ -319,22 +325,17 @@ class MCPTransportBase(ABC):
|
||||
if (
|
||||
guardrails_result
|
||||
and guardrails_result.get("errors", [])
|
||||
and MCPTransportBase.check_if_new_errors(
|
||||
and McpTransportBase.check_if_new_errors(
|
||||
session_id, session_store, guardrails_result
|
||||
)
|
||||
):
|
||||
is_blocked = True
|
||||
|
||||
if not is_tools_list:
|
||||
result = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": response_body.get("id"),
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": INVARIANT_GUARDRAILS_BLOCKED_MESSAGE
|
||||
% guardrails_result["errors"],
|
||||
},
|
||||
}
|
||||
result = McpTransportBase._create_jsonrpc_error_response(
|
||||
response_body,
|
||||
INVARIANT_GUARDRAILS_BLOCKED_MESSAGE % guardrails_result["errors"],
|
||||
)
|
||||
else:
|
||||
# Special error response for tools/list
|
||||
result = {
|
||||
@@ -372,7 +373,7 @@ class MCPTransportBase(ABC):
|
||||
@staticmethod
|
||||
async def intercept_response(
|
||||
session_id: str, session_store: McpSessionsManager, response_body: dict
|
||||
) -> Tuple[dict, bool]:
|
||||
) -> tuple[dict, bool]:
|
||||
"""
|
||||
Intercept the response and check for guardrails.
|
||||
This function is used to intercept responses and check for guardrails.
|
||||
@@ -386,7 +387,7 @@ class MCPTransportBase(ABC):
|
||||
response_body (dict): The response JSON to be processed.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, bool]: A tuple containing the processed response JSON
|
||||
tuple[dict, bool]: A tuple containing the processed response JSON
|
||||
and a boolean indicating whether the response was blocked.
|
||||
"""
|
||||
session = session_store.get_session(session_id)
|
||||
@@ -400,7 +401,7 @@ class MCPTransportBase(ABC):
|
||||
(
|
||||
intercept_response_result,
|
||||
is_blocked,
|
||||
) = await MCPTransportBase.hook_tool_call_response(
|
||||
) = await McpTransportBase.hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body=response_body,
|
||||
@@ -414,7 +415,7 @@ class MCPTransportBase(ABC):
|
||||
(
|
||||
intercept_response_result,
|
||||
is_blocked,
|
||||
) = await MCPTransportBase.hook_tool_call_response(
|
||||
) = await McpTransportBase.hook_tool_call_response(
|
||||
session_id=session_id,
|
||||
session_store=session_store,
|
||||
response_body={
|
||||
@@ -431,9 +432,9 @@ class MCPTransportBase(ABC):
|
||||
return intercept_response_result, is_blocked
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
async def initialize_session(self, **kwargs) -> str:
|
||||
"""Initialize a session for this transport type."""
|
||||
|
||||
@abstractmethod
|
||||
async def handle_communication(self, *args, **kwargs) -> Any:
|
||||
async def handle_communication(self, **kwargs) -> Any:
|
||||
"""Handle the main communication for this transport."""
|
||||
|
||||
+11
-12
@@ -3,20 +3,20 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Any, AsyncGenerator, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from httpx_sse import aconnect_sse, ServerSentEvent
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import CLIENT_TIMEOUT
|
||||
from gateway.common.constants import CLIENT_TIMEOUT, CONTENT_TYPE_EVENT_STREAM
|
||||
from gateway.mcp.constants import MCP_CUSTOM_HEADER_PREFIX, UTF_8
|
||||
from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
MCP_SERVER_POST_HEADERS = {
|
||||
"connection",
|
||||
@@ -62,7 +62,7 @@ async def create_sse_transport_and_handle_post(
|
||||
raise HTTPException(status_code=400, detail="Session does not exist")
|
||||
|
||||
request_body = json.loads(await request.body())
|
||||
return await SSETransport(session_store).handle_post_request(
|
||||
return await SseTransport(session_store).handle_post_request(
|
||||
request, session_id, request_body
|
||||
)
|
||||
|
||||
@@ -71,10 +71,10 @@ async def create_sse_transport_and_handle_stream(
|
||||
request: Request, session_store: McpSessionsManager
|
||||
) -> StreamingResponse:
|
||||
"""Integration function for SSE GET route."""
|
||||
return await SSETransport(session_store).handle_sse_stream(request)
|
||||
return await SseTransport(session_store).handle_sse_stream(request)
|
||||
|
||||
|
||||
class SSETransport(MCPTransportBase):
|
||||
class SseTransport(McpTransportBase):
|
||||
"""
|
||||
Server-Sent Events transport implementation for MCP communication.
|
||||
Handles HTTP-based SSE communication with message queuing.
|
||||
@@ -82,12 +82,11 @@ class SSETransport(MCPTransportBase):
|
||||
|
||||
async def initialize_session(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Initialize or get existing SSE session."""
|
||||
session_id: Optional[str] = kwargs.get("session_id", None)
|
||||
session_attributes: Optional[McpAttributes] = kwargs.get(
|
||||
session_id: str | None = kwargs.get("session_id", None)
|
||||
session_attributes: McpAttributes | None = kwargs.get(
|
||||
"session_attributes", None
|
||||
)
|
||||
if session_id and self.session_store.session_exists(session_id):
|
||||
@@ -294,17 +293,17 @@ class SSETransport(MCPTransportBase):
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
media_type=CONTENT_TYPE_EVENT_STREAM,
|
||||
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
|
||||
)
|
||||
|
||||
async def handle_communication(self, *args, **kwargs) -> StreamingResponse:
|
||||
async def handle_communication(self, **kwargs) -> StreamingResponse:
|
||||
"""Main communication handler for SSE transport."""
|
||||
return await self.handle_sse_stream(kwargs.get("request"))
|
||||
|
||||
async def _handle_endpoint_event(
|
||||
self, sse: ServerSentEvent, sse_header_attributes: McpAttributes
|
||||
) -> Tuple[bytes, str]:
|
||||
) -> tuple[bytes, str]:
|
||||
"""Handle endpoint event and initialize session if needed."""
|
||||
match = re.search(r"session_id=([^&\s]+)", sse.data)
|
||||
session_id = match.group(1) if match else None
|
||||
|
||||
@@ -7,7 +7,6 @@ import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gateway.mcp.constants import UTF_8
|
||||
from gateway.mcp.log import mcp_log, MCP_LOG_FILE
|
||||
@@ -15,7 +14,7 @@ from gateway.mcp.mcp_sessions_manager import (
|
||||
McpAttributes,
|
||||
McpSessionsManager,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
STATUS_EOF = "eof"
|
||||
STATUS_DATA = "data"
|
||||
@@ -23,7 +22,7 @@ STATUS_WAIT = "wait"
|
||||
mcp_sessions_manager = McpSessionsManager()
|
||||
|
||||
|
||||
class StdioTransport(MCPTransportBase):
|
||||
class StdioTransport(McpTransportBase):
|
||||
"""
|
||||
STDIO transport implementation for MCP communication.
|
||||
Handles subprocess-based communication with stdin/stdout/stderr.
|
||||
@@ -33,7 +32,7 @@ class StdioTransport(MCPTransportBase):
|
||||
super().__init__(session_store)
|
||||
self.mcp_process: subprocess.Popen = None
|
||||
|
||||
async def initialize_session(self, *args, **kwargs) -> str:
|
||||
async def initialize_session(self, **kwargs) -> str:
|
||||
"""Initialize session for stdio transport."""
|
||||
session_attributes: McpAttributes = kwargs.get("session_attributes")
|
||||
session_id = self.generate_session_id()
|
||||
@@ -53,7 +52,7 @@ class StdioTransport(MCPTransportBase):
|
||||
mcp_log(f"Started MCP process with PID: {self.mcp_process.pid}")
|
||||
return self.mcp_process
|
||||
|
||||
async def handle_communication(self, *args, **kwargs) -> None:
|
||||
async def handle_communication(self, **kwargs) -> None:
|
||||
"""Handle stdio communication loop."""
|
||||
session_id: str = kwargs.get("session_id")
|
||||
mcp_process: subprocess.Popen = kwargs.get("mcp_process")
|
||||
@@ -210,7 +209,7 @@ class StdioTransport(MCPTransportBase):
|
||||
|
||||
async def _wait_for_stdin_input(
|
||||
self, loop: asyncio.AbstractEventLoop, stdin_fd: int
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
) -> tuple[bytes | None, str]:
|
||||
"""Platform-specific implementation to wait for and read input from stdin."""
|
||||
if platform.system() == "Windows":
|
||||
await asyncio.sleep(0.01)
|
||||
@@ -261,7 +260,7 @@ async def create_stdio_transport_and_execute(
|
||||
)
|
||||
|
||||
|
||||
def split_args(args: list[str] = None) -> tuple[list[str], list[str]]:
|
||||
def split_args(args: list[str] | None = None) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Splits CLI arguments into two parts:
|
||||
1. Arguments intended for the MCP gateway (everything before `--exec`)
|
||||
|
||||
+23
-21
@@ -1,14 +1,19 @@
|
||||
"""Gateway service to forward requests to the MCP Streamable HTTP servers"""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from httpx_sse import aconnect_sse
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from gateway.common.constants import CLIENT_TIMEOUT
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
CONTENT_TYPE_HEADER,
|
||||
CONTENT_TYPE_JSON,
|
||||
CONTENT_TYPE_EVENT_STREAM,
|
||||
)
|
||||
from gateway.mcp.constants import (
|
||||
INVARIANT_SESSION_ID_PREFIX,
|
||||
MCP_CUSTOM_HEADER_PREFIX,
|
||||
@@ -18,16 +23,13 @@ from gateway.mcp.mcp_sessions_manager import (
|
||||
McpSessionsManager,
|
||||
McpAttributes,
|
||||
)
|
||||
from gateway.mcp.mcp_transport_base import MCPTransportBase
|
||||
from gateway.mcp.mcp_transport_base import McpTransportBase
|
||||
|
||||
gateway = APIRouter()
|
||||
mcp_sessions_manager = McpSessionsManager()
|
||||
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_SSE = "text/event-stream"
|
||||
CONTENT_TYPE_HEADER = "content-type"
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
MCP_SERVER_POST_DELETE_HEADERS = {
|
||||
MCP_SERVER_POST_AND_DELETE_HEADERS = {
|
||||
"connection",
|
||||
"accept",
|
||||
CONTENT_TYPE_HEADER,
|
||||
@@ -69,7 +71,7 @@ async def mcp_delete_streamable_gateway(request: Request) -> Response:
|
||||
|
||||
async def create_streamable_transport_and_handle_request(
|
||||
request: Request, method: str, session_store: McpSessionsManager
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Integration function for streamable routes."""
|
||||
streamable_transport = StreamableTransport(session_store)
|
||||
return await streamable_transport.handle_communication(
|
||||
@@ -77,7 +79,7 @@ async def create_streamable_transport_and_handle_request(
|
||||
)
|
||||
|
||||
|
||||
class StreamableTransport(MCPTransportBase):
|
||||
class StreamableTransport(McpTransportBase):
|
||||
"""
|
||||
Streamable HTTP transport implementation for MCP communication.
|
||||
Handles HTTP POST/GET/DELETE requests with JSON and streaming responses.
|
||||
@@ -85,12 +87,11 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
async def initialize_session(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Initialize streamable HTTP session."""
|
||||
session_id: Optional[str] = kwargs.get("session_id", None)
|
||||
session_attributes: Optional[McpAttributes] = kwargs.get(
|
||||
session_id: str | None = kwargs.get("session_id", None)
|
||||
session_attributes: McpAttributes | None = kwargs.get(
|
||||
"session_attributes", None
|
||||
)
|
||||
is_initialization_request: bool = kwargs.get("is_initialization_request", False)
|
||||
@@ -111,7 +112,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
async def handle_post_request(
|
||||
self, request: Request, request_body: dict[str, Any]
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Handle POST request to streamable endpoint."""
|
||||
session_attributes = McpAttributes.from_request_headers(request.headers)
|
||||
session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
@@ -188,7 +189,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type=CONTENT_TYPE_SSE,
|
||||
media_type=CONTENT_TYPE_EVENT_STREAM,
|
||||
headers={"X-Proxied-By": "mcp-gateway", **response_headers},
|
||||
)
|
||||
|
||||
@@ -222,9 +223,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
print(f"[MCP DELETE] Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Request error") from e
|
||||
|
||||
async def handle_communication(
|
||||
self, *args, **kwargs
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
async def handle_communication(self, **kwargs) -> Response | StreamingResponse:
|
||||
"""Main communication handler for streamable transport."""
|
||||
request = kwargs.get("request")
|
||||
method = kwargs.get("method", "POST")
|
||||
@@ -241,7 +240,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
async def _process_non_init_request(
|
||||
self, session_id: str, request_body: dict[str, Any]
|
||||
) -> Optional[Response]:
|
||||
) -> Response | None:
|
||||
"""Process non-initialization requests for guardrails."""
|
||||
processed_request, is_blocked = await self.process_outgoing_request(
|
||||
session_id, request_body
|
||||
@@ -262,7 +261,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
session_id: str,
|
||||
session_attributes: McpAttributes,
|
||||
is_initialization_request: bool,
|
||||
) -> Union[Response, StreamingResponse]:
|
||||
) -> Response | StreamingResponse:
|
||||
"""Forward request to MCP server and handle response."""
|
||||
async with httpx.AsyncClient(timeout=CLIENT_TIMEOUT) as client:
|
||||
try:
|
||||
@@ -385,7 +384,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type=CONTENT_TYPE_SSE,
|
||||
media_type=CONTENT_TYPE_EVENT_STREAM,
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
@@ -395,6 +394,9 @@ class StreamableTransport(MCPTransportBase):
|
||||
"""Update MCP response info in session metadata."""
|
||||
session = self.session_store.get_session(session_id)
|
||||
self.update_mcp_server_in_session_metadata(session, response_body)
|
||||
session.attributes.metadata["is_stateless_http_server"] = session_id.startswith(
|
||||
INVARIANT_SESSION_ID_PREFIX
|
||||
)
|
||||
session.attributes.metadata["server_response_type"] = (
|
||||
"json" if is_json_response else "sse"
|
||||
)
|
||||
@@ -405,7 +407,7 @@ class StreamableTransport(MCPTransportBase):
|
||||
for k, v in request.headers.items():
|
||||
if k.startswith(MCP_CUSTOM_HEADER_PREFIX):
|
||||
filtered_headers[k.removeprefix(MCP_CUSTOM_HEADER_PREFIX)] = v
|
||||
if k.lower() in MCP_SERVER_POST_DELETE_HEADERS and not (
|
||||
if k.lower() in MCP_SERVER_POST_AND_DELETE_HEADERS and not (
|
||||
k.lower() == MCP_SESSION_ID_HEADER
|
||||
and v.startswith(INVARIANT_SESSION_ID_PREFIX)
|
||||
):
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Task utilities for running async functions"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
from typing import Any
|
||||
|
||||
from gateway.mcp.log import MCP_LOG_FILE
|
||||
|
||||
|
||||
def run_task_sync(async_func, *args, **kwargs) -> Any:
    """
    Run an asynchronous function synchronously and return its result.

    The coroutine is executed on a brand-new event loop inside a worker
    thread, and the calling thread blocks until it finishes or until the
    10 second deadline expires. While the call is in flight, ``sys.stdout``
    is redirected to ``MCP_LOG_FILE``.

    Args:
        async_func: The async function to run.
        *args: Positional arguments to pass to the async function.
        **kwargs: Keyword arguments to pass to the async function.

    Returns:
        Any: The return value of the async function.

    Raises:
        TimeoutError: If the async function does not complete within
            10 seconds.
        Exception: Any exception raised by the async function is re-raised
            in the calling thread.
    """
    timeout_seconds = 10.0

    def run_in_new_loop():
        loop = asyncio.new_event_loop()
        try:
            # Enforce the deadline inside the loop. Waiting on the future
            # with a timeout is not enough: leaving the executor's ``with``
            # block joins the worker thread, so the caller would still block
            # until the coroutine finished on its own. ``wait_for`` cancels
            # the coroutine at the deadline so the worker can actually end.
            return loop.run_until_complete(
                asyncio.wait_for(
                    async_func(*args, **kwargs),
                    timeout=timeout_seconds,
                )
            )
        finally:
            loop.close()

    with redirect_stdout(MCP_LOG_FILE):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(run_in_new_loop)
            # No timeout here: the worker enforces the deadline itself and
            # its TimeoutError propagates through the future. On Python
            # 3.11+ asyncio.TimeoutError and concurrent.futures.TimeoutError
            # are both the builtin TimeoutError, so callers catching either
            # keep working.
            return future.result()
|
||||
+33
-16
@@ -14,7 +14,12 @@ from gateway.common.config_manager import (
|
||||
GatewayConfigManager,
|
||||
extract_guardrails_from_header,
|
||||
)
|
||||
from gateway.common.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
CONTENT_TYPE_JSON,
|
||||
CONTENT_TYPE_EVENT_STREAM,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
from gateway.common.request_context import RequestContext
|
||||
from gateway.converters.anthropic_to_invariant import (
|
||||
@@ -64,7 +69,7 @@ def validate_headers(x_api_key: str = Header(None)):
|
||||
)
|
||||
async def anthropic_v1_messages_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
|
||||
):
|
||||
@@ -162,7 +167,7 @@ async def get_guardrails_check_result(
|
||||
async def push_to_explorer(
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
guardrails_execution_result: dict | None = None,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
@@ -210,15 +215,18 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
self.anthropic_request: httpx.Request = anthropic_request
|
||||
|
||||
# response data
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_string: Optional[str] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
self.response: httpx.Response | None = None
|
||||
self.response_string: str | None = None
|
||||
self.response_json: dict[str, Any] | None = None
|
||||
|
||||
# guardrailing response (if any)
|
||||
self.guardrails_execution_result = {}
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the first
|
||||
chunk (for input guardrailing).
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context, action=GuardrailAction.BLOCK, response_json={}
|
||||
@@ -249,8 +257,8 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
Response(
|
||||
content=error_chunk,
|
||||
status_code=400,
|
||||
media_type="application/json",
|
||||
headers={"content-type": "application/json"},
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers={"content-type": CONTENT_TYPE_JSON},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -263,7 +271,10 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=self.response.status_code,
|
||||
detail=f"Invalid JSON response received from Anthropic: {self.response.text}, got error{e}",
|
||||
detail=(
|
||||
"Invalid JSON response received from Anthropic: "
|
||||
f"{self.response.text}, got error: {e}"
|
||||
),
|
||||
) from e
|
||||
if self.response.status_code != 200:
|
||||
raise HTTPException(
|
||||
@@ -289,12 +300,15 @@ class InstrumentedAnthropicResponse(InstrumentedResponse):
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=status_code,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers=dict(updated_headers),
|
||||
)
|
||||
|
||||
async def on_end(self):
|
||||
"""Checks guardrails after the response is received, and asynchronously pushes to Explorer."""
|
||||
"""
|
||||
Checks guardrails after the response is received, and asynchronously
|
||||
pushes to Explorer.
|
||||
"""
|
||||
# ensure the response data is available
|
||||
assert self.response is not None, "response is None"
|
||||
assert self.response_json is not None, "response_json is None"
|
||||
@@ -383,7 +397,10 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
self.sse_buffer = "" # Buffer for incomplete events
|
||||
|
||||
async def on_start(self):
|
||||
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
|
||||
"""
|
||||
Check guardrails in a pipelined fashion, before processing the
|
||||
first chunk (for input guardrailing).
|
||||
"""
|
||||
if self.context.guardrails:
|
||||
self.guardrails_execution_result = await get_guardrails_check_result(
|
||||
self.context,
|
||||
@@ -503,7 +520,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
f"JSON parsing error in event: {e}. Event data: {event_data[:100]}...",
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Error processing event: {e}", flush=True)
|
||||
|
||||
# on last stream chunk, run output guardrails
|
||||
@@ -536,7 +553,7 @@ class InstrumentedAnthropicStreamingResponse(InstrumentedStreamingResponse):
|
||||
"""Process the buffer and extract complete SSE events.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], str]: A tuple containing a list of
|
||||
tuple[list[str], str]: A tuple containing a list of
|
||||
complete events and the remaining buffer with incomplete events.
|
||||
"""
|
||||
# Split on double newlines which separate SSE events
|
||||
@@ -582,7 +599,7 @@ async def handle_streaming_response(
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
response.instrumented_event_generator(), media_type="text/event-stream"
|
||||
response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM
|
||||
)
|
||||
|
||||
|
||||
|
||||
+23
-15
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
||||
@@ -16,6 +16,8 @@ from gateway.common.config_manager import (
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
CONTENT_TYPE_JSON,
|
||||
CONTENT_TYPE_EVENT_STREAM,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
@@ -47,7 +49,7 @@ async def gemini_generate_content_gateway(
|
||||
api_version: str,
|
||||
model: str,
|
||||
endpoint: str,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
|
||||
alt: str = Query(
|
||||
None, title="Response Format", description="Set to 'sse' for streaming"
|
||||
),
|
||||
@@ -58,8 +60,10 @@ async def gemini_generate_content_gateway(
|
||||
if endpoint not in ["generateContent", "streamGenerateContent"]:
|
||||
return Response(
|
||||
content="Invalid endpoint - the only endpoints supported are: \
|
||||
/api/v1/gateway/gemini/<version>/models/<model-name>:generateContent or \
|
||||
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:generateContent",
|
||||
/api/v1/gateway/gemini/<version>/models/<model-name>:generateContent \
|
||||
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:generateContent \
|
||||
/api/v1/gateway/gemini/<version>/models/<model-name>:streamGenerateContent or \
|
||||
/api/v1/gateway/<dataset-name>/gemini/<version>models/<model-name>:streamGenerateContent",
|
||||
status_code=400,
|
||||
)
|
||||
headers = {
|
||||
@@ -80,7 +84,11 @@ async def gemini_generate_content_gateway(
|
||||
request_json = json.loads(request_body_bytes)
|
||||
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
|
||||
gemini_api_url = (
|
||||
f"https://generativelanguage.googleapis.com/"
|
||||
f"{api_version}/models/"
|
||||
f"{model}:{endpoint}"
|
||||
)
|
||||
if alt == "sse":
|
||||
gemini_api_url += "?alt=sse"
|
||||
gemini_request = client.build_request(
|
||||
@@ -139,7 +147,7 @@ class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
|
||||
}
|
||||
|
||||
# guardrailing execution result (if any)
|
||||
self.guardrails_execution_result: Optional[dict[str, Any]] = None
|
||||
self.guardrails_execution_result: dict[str, Any] | None = None
|
||||
|
||||
def make_refusal(
|
||||
self,
|
||||
@@ -301,7 +309,7 @@ async def stream_response(
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
media_type=CONTENT_TYPE_EVENT_STREAM,
|
||||
)
|
||||
|
||||
|
||||
@@ -407,7 +415,7 @@ async def get_guardrails_check_result(
|
||||
async def push_to_explorer(
|
||||
context: RequestContext,
|
||||
response_json: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
guardrails_execution_result: dict | None = None,
|
||||
) -> None:
|
||||
"""Pushes the full trace to the Invariant Explorer"""
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
@@ -456,11 +464,11 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
self.gemini_request: httpx.Request = gemini_request
|
||||
|
||||
# response data
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
self.response: httpx.Response | None = None
|
||||
self.response_json: dict[str, Any] | None = None
|
||||
|
||||
# guardrails execution result (if any)
|
||||
self.guardrails_execution_result: Optional[dict[str, Any]] = None
|
||||
self.guardrails_execution_result: dict[str, Any] | None = None
|
||||
|
||||
async def on_start(self):
|
||||
"""
|
||||
@@ -509,9 +517,9 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
Response(
|
||||
content=error_chunk,
|
||||
status_code=400,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Content-Type": CONTENT_TYPE_JSON,
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -539,7 +547,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
return Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers=dict(self.response.headers),
|
||||
)
|
||||
|
||||
@@ -582,7 +590,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers=dict(self.response.headers),
|
||||
)
|
||||
)
|
||||
|
||||
+18
-15
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
@@ -16,6 +16,8 @@ from gateway.common.config_manager import (
|
||||
)
|
||||
from gateway.common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
CONTENT_TYPE_JSON,
|
||||
CONTENT_TYPE_EVENT_STREAM,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
||||
@@ -60,14 +62,14 @@ def make_cors_response(request: Request, allow_methods: str) -> Response:
|
||||
|
||||
@gateway.options("/{dataset_name}/openai/chat/completions")
|
||||
@gateway.options("/openai/chat/completions")
|
||||
async def openai_chat_completions_options(request: Request, dataset_name: str = None):
|
||||
async def openai_chat_completions_options(request: Request):
|
||||
"""Enables CORS for the OpenAI chat completions endpoint"""
|
||||
return make_cors_response(request, allow_methods="POST")
|
||||
|
||||
|
||||
@gateway.options("/{dataset_name}/openai/models")
|
||||
@gateway.options("/openai/models")
|
||||
async def openai_models_options(request: Request, dataset_name: str = None):
|
||||
async def openai_models_options(request: Request):
|
||||
"""Enables CORS for the OpenAI models endpoint"""
|
||||
return make_cors_response(request, allow_methods="GET")
|
||||
|
||||
@@ -76,7 +78,7 @@ async def openai_models_options(request: Request, dataset_name: str = None):
|
||||
@gateway.get("/openai/models")
|
||||
async def openai_models_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
|
||||
):
|
||||
"""Proxy request to OpenAI /models endpoint"""
|
||||
headers = {
|
||||
@@ -110,7 +112,7 @@ async def openai_models_gateway(
|
||||
)
|
||||
async def openai_chat_completions_gateway(
|
||||
request: Request,
|
||||
dataset_name: str = None, # This is None if the client doesn't want to push to Explorer
|
||||
dataset_name: str | None = None, # This is None if the client doesn't want to push to Explorer
|
||||
config: GatewayConfig = Depends(GatewayConfigManager.get_config), # pylint: disable=unused-argument
|
||||
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
|
||||
) -> Response:
|
||||
@@ -180,7 +182,7 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
# guardrailing output (if any)
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
self.guardrails_execution_result: dict | None = None
|
||||
|
||||
# merged_response will be updated with the data from the chunks in the stream
|
||||
# At the end of the stream, this will be sent to the explorer
|
||||
@@ -273,7 +275,8 @@ class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse):
|
||||
}
|
||||
)
|
||||
|
||||
# yield an extra error chunk (without preventing the original chunk to go through after)
|
||||
# yield an extra error chunk (without preventing the original
|
||||
# chunk to go through after)
|
||||
return ExtraItem(f"data: {error_chunk}\n\n".encode())
|
||||
|
||||
# push will happen in on_end
|
||||
@@ -324,7 +327,7 @@ async def handle_stream_response(
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
response.instrumented_event_generator(), media_type="text/event-stream"
|
||||
response.instrumented_event_generator(), media_type=CONTENT_TYPE_EVENT_STREAM
|
||||
)
|
||||
|
||||
|
||||
@@ -483,7 +486,7 @@ def create_metadata(
|
||||
async def push_to_explorer(
|
||||
context: RequestContext,
|
||||
merged_response: dict[str, Any],
|
||||
guardrails_execution_result: Optional[dict] = None,
|
||||
guardrails_execution_result: dict | None = None,
|
||||
) -> None:
|
||||
"""Pushes the merged response to the Invariant Explorer"""
|
||||
# Only push the trace to explorer if the message is an end turn message
|
||||
@@ -569,11 +572,11 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
self.open_ai_request: httpx.Request = open_ai_request
|
||||
|
||||
# request outputs
|
||||
self.response: Optional[httpx.Response] = None
|
||||
self.response_json: Optional[dict[str, Any]] = None
|
||||
self.response: httpx.Response | None = None
|
||||
self.response_json: dict[str, Any] | None = None
|
||||
|
||||
# guardrailing output (if any)
|
||||
self.guardrails_execution_result: Optional[dict] = None
|
||||
self.guardrails_execution_result: dict | None = None
|
||||
|
||||
async def on_start(self):
|
||||
"""
|
||||
@@ -606,7 +609,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
}
|
||||
),
|
||||
status_code=400,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
),
|
||||
end_of_stream=True,
|
||||
)
|
||||
@@ -634,7 +637,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
return Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
headers=dict(self.response.headers),
|
||||
)
|
||||
|
||||
@@ -686,7 +689,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
media_type=CONTENT_TYPE_JSON,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "invariant-gateway"
|
||||
version = "0.0.5.2"
|
||||
version = "0.0.6"
|
||||
description = "LLM proxy to observe and debug what your AI agents are doing"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
@@ -7,7 +7,6 @@ import sys
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
# Add integration folder (parent) to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -49,7 +48,7 @@ class WeatherAgent:
|
||||
},
|
||||
}
|
||||
|
||||
def get_response(self, messages: List[Dict]) -> List[Dict]:
|
||||
def get_response(self, messages: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Get the response from the agent for a given user query for weather.
|
||||
"""
|
||||
@@ -83,7 +82,7 @@ class WeatherAgent:
|
||||
else:
|
||||
return response_list
|
||||
|
||||
def get_streaming_response(self, messages: List[Dict]) -> List[Dict]:
|
||||
def get_streaming_response(self, messages: list[dict]) -> list[dict]:
|
||||
"""Get streaming response from the agent for a given user query for weather."""
|
||||
response_list = []
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from resources.mcp.sse.client.main import run as mcp_sse_client_run
|
||||
from resources.mcp.stdio.client.main import run as mcp_stdio_client_run
|
||||
from resources.mcp.streamable.client.main import run as mcp_streamable_client_run
|
||||
@@ -12,7 +11,6 @@ import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
# Taken from docker-compose.test.yml
|
||||
MCP_SSE_SERVER_HOST = "mcp-messenger-sse-server"
|
||||
@@ -706,7 +704,7 @@ async def test_mcp_message_timestamps(
|
||||
"""Test that MCP messages include timestamps"""
|
||||
project_name = "test-mcp-" + str(uuid.uuid4())
|
||||
|
||||
# Run the MCP client and make the tool call
|
||||
|
||||
result = await _invoke_mcp_tool(
|
||||
transport,
|
||||
gateway_url,
|
||||
|
||||
@@ -11,7 +11,7 @@ async def run(
|
||||
gateway_url: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
headers: dict[str, str] = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> types.CallToolResult | types.ListToolsResult:
|
||||
"""
|
||||
Run the MCP client with the given parameters.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import os
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters, types
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -14,7 +14,7 @@ def _get_server_params(
|
||||
project_name: str,
|
||||
server_script_path: str,
|
||||
push_to_explorer: bool,
|
||||
metadata_keys: Optional[dict[str, str]] = None,
|
||||
metadata_keys: dict[str, str] | None = None,
|
||||
) -> StdioServerParameters:
|
||||
args = [
|
||||
"--from",
|
||||
@@ -59,7 +59,7 @@ async def run(
|
||||
push_to_explorer: bool,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
metadata_keys: Optional[dict[str, str]] = None,
|
||||
metadata_keys: dict[str, str] | None = None,
|
||||
) -> types.CallToolResult | types.ListToolsResult:
|
||||
"""
|
||||
Main function to setup the MCP client and server.
|
||||
|
||||
@@ -12,7 +12,7 @@ async def run(
|
||||
gateway_url: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
headers: dict[str, str] = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> types.CallToolResult | types.ListToolsResult:
|
||||
"""
|
||||
Run the MCP client with the given parameters.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from httpx import Client
|
||||
from openai import OpenAI
|
||||
@@ -62,8 +62,8 @@ def get_gemini_client(
|
||||
async def create_dataset(
|
||||
explorer_api_url: str,
|
||||
invariant_authorization: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
dataset_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a dataset in the Explorer API."""
|
||||
client = Client(base_url=explorer_api_url)
|
||||
response = client.post(
|
||||
@@ -85,7 +85,7 @@ async def add_guardrail_to_dataset(
|
||||
policy: str,
|
||||
action: Literal["block", "log"],
|
||||
invariant_authorization: str,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Add a guardrail to a dataset."""
|
||||
client = Client(base_url=explorer_api_url)
|
||||
response = client.post(
|
||||
|
||||
Reference in New Issue
Block a user