Files
invariant-gateway/gateway/common/request_context.py
Luca Beurer-Kellner e18c6b5bdb Add an option to add extra metadata that is pushed and passed to Guardrails during an MCP session (#47)
* use select() before readline

* support for setting static metadata for MCP sessions

* nest extra mcp metadata in metadata object

* unify session metadata

* extra metadata tests

* use empty object as parameters, if None

* list_tools as tool call

* offset indices in tests

* test: adjust addresses

* mcp: make error reporting configurable

* line logging

* log version

* verbose logging + loud exception failure

* add server and client name to policy get

* append trace even if not pushing

* port tools/list message support to SSE

* use python -m build

* adjust guardrail failure address

* support for blocking tools/list in SSE

* use error-based failure response format by default

* tools/list test

* don't list_tools in stdio connect

* flaky test: handle second possible result in anthropic streaming case

---------

Co-authored-by: knielsen404 <kristian@invariantlabs.ai>
2025-05-19 13:44:37 +02:00

127 lines
4.5 KiB
Python

"""Common Request context data class."""
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import fastapi
from gateway.common.config_manager import GatewayConfig
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from gateway.common.authorization import (
extract_guardrail_service_authorization_from_headers,
)
@dataclass(frozen=True)
class RequestContext:
    """Structured, immutable context for a single gateway request.

    Instances must be created via `RequestContext.create()`; constructing the
    dataclass directly raises a RuntimeError in `__post_init__`.
    """

    request_json: Dict[str, Any]
    dataset_name: Optional[str] = None
    # authorization to use for invariant service like explorer
    invariant_authorization: Optional[str] = None
    # authorization to use for invariant guardrailing specifically
    guardrail_authorization: Optional[str] = None
    # the set of guardrails to enforce for this request
    guardrails: Optional[GuardrailRuleSet] = None
    # plain-dict copy of the GatewayConfig attributes (without
    # guardrails_from_file); create() always stores a dict here
    config: Optional[Dict[str, Any]] = None
    # extra parameters available as input.<key> during guardrail evaluation
    guardrails_parameters: Optional[Dict[str, Any]] = None
    # set to True only by create(); guards against direct construction and is
    # excluded from repr/eq
    _created_via_factory: bool = field(
        default=False, init=True, repr=False, compare=False
    )

    def __post_init__(self):
        """Reject instances that were not built through `create()`.

        Raises:
            RuntimeError: if `_created_via_factory` is not truthy.
        """
        if not self._created_via_factory:
            raise RuntimeError(
                "RequestContext must be created using RequestContext.create()"
            )

    @classmethod
    def create(
        cls,
        request_json: Dict[str, Any],
        dataset_name: Optional[str] = None,
        invariant_authorization: Optional[str] = None,
        guardrails: Optional[GuardrailRuleSet] = None,
        config: Optional[GatewayConfig] = None,
        request: Optional[fastapi.Request] = None,
        guardrails_parameters: Optional[Dict[str, Any]] = None,
    ) -> "RequestContext":
        """Create a new RequestContext, applying file-configured guardrails if needed.

        Args:
            request_json: request payload; stored unchanged.
            dataset_name: stored unchanged.
            invariant_authorization: authorization for the Invariant service;
                also the fallback for guardrailing (see
                `get_guardrailing_authorization`).
            guardrails: guardrails to enforce. If this is falsy, or contains
                neither blocking nor logging guardrails, and
                `config.guardrails_from_file` is set, it is replaced by a rule
                set with a single blocking guardrail built from that content.
            config: gateway configuration. Its attributes, except
                `guardrails_from_file`, are copied into the `config` dict.
            request: incoming FastAPI request. When given, its headers are
                passed to `extract_guardrail_service_authorization_from_headers`
                to obtain a separate guardrail-service authorization.
            guardrails_parameters: extra parameters exposed as `input.<key>`
                during guardrail evaluation; stored unchanged.

        Returns:
            A fully initialized RequestContext.
        """
        # Snapshot the GatewayConfig as a plain dict; the raw guardrails file
        # content is not carried along (it is turned into a Guardrail below).
        context_config = (
            {
                key: value
                for key, value in config.__dict__.items()
                if key != "guardrails_from_file"
            }
            if config
            else {}
        )

        # Fall back to the guardrails from the config file only when the caller
        # supplied no effective guardrails at all.
        no_effective_guardrails = not guardrails or (
            not guardrails.blocking_guardrails
            and not guardrails.logging_guardrails
        )
        if no_effective_guardrails and config and config.guardrails_from_file:
            guardrails = GuardrailRuleSet(
                blocking_guardrails=[
                    Guardrail(
                        id="guardrails-from-gateway-config-file",
                        name="guardrails from gateway configuration file",
                        content=config.guardrails_from_file,
                        action=GuardrailAction.BLOCK,
                    )
                ],
                logging_guardrails=[],
            )

        # If additionally provided, extract a separate API key to use with the
        # guardrailing service. Compare with None explicitly: a Starlette
        # Request is a Mapping, so its truthiness depends on its scope length.
        # Empty/None results are normalized to None.
        guardrail_service_authorization = None
        if request is not None:
            guardrail_service_authorization = (
                extract_guardrail_service_authorization_from_headers(request)
                or None
            )

        return cls(
            request_json=request_json,
            dataset_name=dataset_name,
            invariant_authorization=invariant_authorization,
            guardrail_authorization=guardrail_service_authorization,
            guardrails=guardrails,
            config=context_config,
            _created_via_factory=True,
            guardrails_parameters=guardrails_parameters,
        )

    def get_guardrailing_authorization(self) -> Optional[str]:
        """
        Returns the authorization to use for the guardrailing service.

        This can be different from the invariant authorization, but falls back
        to be the same if not explicitly set via header.
        See also extract_guardrail_service_authorization_from_headers(...)
        """
        return self.guardrail_authorization or self.invariant_authorization

    def __repr__(self) -> str:
        """Return a debug representation with the Invariant key masked.

        Only the last four characters of `invariant_authorization` are shown.
        When no authorization is set, `None` is printed instead of raising,
        since `invariant_authorization` is optional.
        """
        if self.invariant_authorization is not None:
            masked_authorization = f"inv-*****{self.invariant_authorization[-4:]}"
        else:
            masked_authorization = None
        return (
            f"RequestContext("
            f"request_json={self.request_json}, "
            f"dataset_name={self.dataset_name}, "
            f"invariant_authorization={masked_authorization}, "
            f"guardrails={self.guardrails}, "
            f"config={self.config})"
        )