mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-02-12 14:32:45 +00:00
192 lines
6.2 KiB
Python
192 lines
6.2 KiB
Python
"""Utility functions for the Invariant explorer."""
|
|
|
|
import os
|
|
from typing import Any, Dict, List
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
|
|
from invariant_sdk.async_client import AsyncClient
|
|
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
|
|
from invariant_sdk.types.annotations import AnnotationCreate
|
|
|
|
import httpx
|
|
|
|
# Explorer base URL used when the INVARIANT_API_URL environment variable is not set.
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
|
|
|
|
|
def create_annotations_from_guardrails_errors(
    guardrails_errors: List[dict], action: str = "block"
) -> List[AnnotationCreate]:
    """Create Explorer annotations from the guardrails errors.

    Every error contributes one annotation per reported range, after ranges
    that are only ancestors of a more specific range of the same error have
    been dropped.

    Args:
        guardrails_errors: Error dicts. The keys read here are ``args`` (its
            first element becomes the annotation content), ``ranges`` (the
            addresses to annotate) and the optional ``guardrail``.
        action: Currently unused; kept so that existing callers keep working.

    Returns:
        The annotations in the order of the errors; within one error the
        addresses are ordered by length, shortest first.
    """

    def _is_ancestor(prefix: str, candidate: str) -> bool:
        """Return True if `candidate` is a strictly more specific address under `prefix`."""
        if candidate == prefix or not candidate.startswith(prefix):
            return False
        # Require a path boundary right after the prefix, so that 'messages.2'
        # is not mistaken for an ancestor of 'messages.20' (and ':25-30' not
        # for an ancestor of ':25-300').
        return not prefix or candidate[len(prefix)] in ".:"

    def _remove_prefixes(ranges: list[str]) -> list[str]:
        """
        Remove prefixes from the list of ranges.

        If the ranges are ['messages.2', 'messages.2.content:25-30', 'messages.2.content']
        then this returns ['messages.2.content:25-30'].
        """
        # Sorting by length guarantees that any more specific address comes
        # after its ancestors, so only the tail has to be inspected.
        ordered = sorted(ranges, key=len)
        return [
            address
            for index, address in enumerate(ordered)
            if not any(_is_ancestor(address, other) for other in ordered[index + 1 :])
        ]

    annotations = []
    for error in guardrails_errors:
        # Tolerate errors without (or with empty) 'args' instead of crashing.
        args = error.get("args") or []
        content = args[0] if args else ""
        # 'ranges' may be absent or explicitly None.
        filtered_ranges = _remove_prefixes(list(error.get("ranges") or []))
        guardrail = error.get("guardrail")

        for address in filtered_ranges:
            # Build a fresh dict per annotation so they never share state.
            extra_metadata = {"source": "guardrails-error"}
            if guardrail:
                # if included in error, also include information about guardrail source
                extra_metadata["guardrail"] = guardrail
            annotations.append(
                AnnotationCreate(
                    content=content,
                    address=address,
                    extra_metadata=extra_metadata,
                )
            )
    return annotations
|
|
|
|
|
|
def get_explorer_api_url() -> str:
    """Return the Explorer API URL.

    The INVARIANT_API_URL environment variable takes precedence; when it is
    not set, DEFAULT_API_URL is returned.
    """
    return os.environ.get("INVARIANT_API_URL", DEFAULT_API_URL)
|
|
|
|
|
|
async def push_trace(
    messages: List[List[Dict[str, Any]]],
    dataset_name: str,
    invariant_authorization: str,
    annotations: List[List[AnnotationCreate]] = None,
    metadata: List[Dict[str, Any]] = None,
) -> PushTracesResponse:
    """Pushes traces to the dataset on the Invariant Explorer.

    If a dataset with the given name does not exist, it will be created.

    Args:
        messages (List[List[Dict[str, Any]]]): List of messages to push.
            Keys whose value is None are dropped from every message first.
        dataset_name (str): Name of the dataset.
        invariant_authorization (str): Value of the
            invariant-authorization header, normally "Bearer <api key>".
            A value without the "Bearer " prefix is used as the API key as-is.
        annotations (List[List[AnnotationCreate]]): Optional annotations
            passed through to the push request.
        metadata (List[Dict[str, Any]]): Optional metadata passed through to
            the push request.

    Returns:
        PushTracesResponse: Response containing the trace ID details.
        If the push itself fails, the failure is printed and a dict of the
        form {"error": "<message>"} is returned instead of raising.
    """
    # Remove any None values from the messages
    update_messages = [
        [{k: v for k, v in msg.items() if v is not None} for msg in msg_list]
        for msg_list in messages
    ]
    request = PushTracesRequest(
        messages=update_messages,
        annotations=annotations,
        dataset=dataset_name,
        metadata=metadata,
    )

    # Strip the "Bearer " scheme when present. A header without it used to
    # raise IndexError here, outside the try block below; fall back to the
    # raw value instead.
    auth_parts = invariant_authorization.split("Bearer ")
    api_key = auth_parts[1] if len(auth_parts) > 1 else invariant_authorization

    client = AsyncClient(
        api_url=get_explorer_api_url().rstrip("/"),
        api_key=api_key,
    )
    try:
        return await client.push_trace(request)
    except Exception as e:
        # Pushing is best-effort: report the failure to the caller as data.
        print(f"Failed to push trace: {e}")
        return {"error": str(e)}
|
|
|
|
|
|
async def fetch_guardrails_from_explorer(
    dataset_name: str, invariant_authorization: str
) -> GuardrailRuleSet:
    """Get the guardrails for the dataset.

    The username is looked up from the identity endpoint first, then the
    policies of that user's dataset are fetched and grouped by action.
    Disabled guardrails and guardrails with an unknown action are skipped.

    Args:
        dataset_name: Name of the dataset whose policies are fetched.
        invariant_authorization: Value sent as the Authorization header.

    Returns:
        GuardrailRuleSet: The guardrails for the dataset grouped by their action.
        Both groups are empty when the dataset does not exist (HTTP 404).

    Raises:
        HTTPException: With status 401 when the identity request is rejected
            with 401.
        ValueError: When the identity request returns any other non-200
            status, or the policy request returns a status other than
            200 or 404.
    """

    # TODO: Implement a single API in explorer backend which can return
    # dataset details without requiring a username.

    # Use the client as a context manager so that its connections are
    # released on every exit path, including the raises below.
    async with httpx.AsyncClient(
        base_url=get_explorer_api_url().rstrip("/"),
        headers={
            "Authorization": invariant_authorization,
        },
    ) as client:
        # Get the user details.
        user_info_response = await client.get("/api/v1/user/identity")
        if user_info_response.status_code == 401:
            raise HTTPException(
                status_code=401,
                detail="Invalid Invariant API key. Please check your API key.",
            )
        if user_info_response.status_code != 200:
            raise ValueError(
                f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}"
            )
        user_details = user_info_response.json()
        username = user_details["username"]

        # Get the dataset policies.
        policies_response = await client.get(
            f"/api/v1/dataset/byuser/{username}/{dataset_name}/policy"
        )
        if policies_response.status_code == 404:
            # If the dataset does not exist, return empty guardrails.
            return GuardrailRuleSet(
                blocking_guardrails=[],
                logging_guardrails=[],
            )
        if policies_response.status_code != 200:
            raise ValueError(
                f"Failed to get dataset details from Explorer: {policies_response.status_code}, {policies_response.text}"
            )
        policies_details = policies_response.json()

    guardrails = policies_details.get("policies", [])

    blocking_guardrails = []
    logging_guardrails = []
    for g in guardrails:
        action = g["action"]

        if not g["enabled"]:
            # Skip guardrails that are not enabled.
            continue

        if action not in (GuardrailAction.BLOCK, GuardrailAction.LOG):
            print("[Warning] Skipping unknown guardrail action: ", action)
            continue

        guardrail = Guardrail(
            id=g["id"],
            name=g["name"],
            content=g["content"],
            action=GuardrailAction(action),
        )

        if action == GuardrailAction.BLOCK:
            blocking_guardrails.append(guardrail)
        else:
            logging_guardrails.append(guardrail)

    return GuardrailRuleSet(
        blocking_guardrails=blocking_guardrails,
        logging_guardrails=logging_guardrails,
    )
|