diff --git a/gateway/integrations/explorer.py b/gateway/integrations/explorer.py index dd15235..5cb3671 100644 --- a/gateway/integrations/explorer.py +++ b/gateway/integrations/explorer.py @@ -3,6 +3,8 @@ import os from typing import Any, Dict, List +from fastapi import HTTPException + from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from invariant_sdk.async_client import AsyncClient from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse @@ -62,6 +64,11 @@ def create_annotations_from_guardrails_errors( return annotations +def get_explorer_api_url() -> str: + return "https://preview-explorer.invariantlabs.ai" + return os.getenv("INVARIANT_API_URL", DEFAULT_API_URL) + + async def push_trace( messages: List[List[Dict[str, Any]]], dataset_name: str, @@ -94,7 +101,7 @@ async def push_trace( metadata=metadata, ) client = AsyncClient( - api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + api_url=get_explorer_api_url().rstrip("/"), api_key=invariant_authorization.split("Bearer ")[1], ) try: @@ -117,7 +124,7 @@ async def fetch_guardrails_from_explorer( # dataset details without requiring a username. client = httpx.AsyncClient( - base_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL).rstrip("/"), + base_url=get_explorer_api_url().rstrip("/"), headers={ "Authorization": invariant_authorization, }, @@ -125,7 +132,12 @@ async def fetch_guardrails_from_explorer( # Get the user details. user_info_response = await client.get("/api/v1/user/identity") - if user_info_response.status_code != 200: + if user_info_response.status_code == 401: + raise HTTPException( + status_code=401, + detail="Invalid Invariant API key. Please check your API key.", + ) + elif user_info_response.status_code != 200: raise ValueError( f"Failed to get user details from Explorer: {user_info_response.status_code}, {user_info_response.text}" ) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 4bfc9f5..d0f47ac 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -6,9 +6,13 @@ import time from typing import Any, Dict, List from functools import wraps +from fastapi import HTTPException import httpx from common.guardrails import Guardrail from common.request_context import RequestContext +from common.authorization import ( + INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER, +) DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -367,6 +371,13 @@ async def check_guardrails( }, ) if not result.is_success: + if result.status_code == 401: + raise HTTPException( + status_code=401, + detail="The provided Invariant API key is not valid for guardrail checking. Please ensure you are using the correct API key or pass an alternative API key for guardrail checking specifically via the '{}' header.".format( + INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER + ), + ) raise Exception( f"Guardrails check failed: {result.status_code} - {result.text}" ) @@ -397,6 +408,8 @@ async def check_guardrails( ] } return aggregated_errors + except HTTPException as e: + raise e except Exception as e: print(f"Failed to verify guardrails: {e}") # make sure runtime errors are also visible in e.g. Explorer diff --git a/tests/test-client.py b/tests/test-client.py index 388e453..2b7098b 100644 --- a/tests/test-client.py +++ b/tests/test-client.py @@ -18,7 +18,7 @@ openai_client = OpenAI( "Invariant-Authorization": "Bearer " + os.getenv("INVARIANT_API_KEY"), "Invariant-Guardrails": guardrails, }, - base_url="http://localhost:9999/api/v1/gateway/non-streaming/openai", + base_url="http://localhost:8005/api/v1/gateway/non-streaming/openai", ) response = openai_client.chat.completions.create(