mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-06 21:23:55 +02:00
Add push-to-Explorer support for non-streaming Gemini responses.
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
"""Common Authorization functions used in the gateway."""
|
||||
|
||||
from typing import Tuple, Optional
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
# Name of the request header that carries the Invariant API key, formatted as
# "Bearer <Invariant API Key>".
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
|
||||
# Delimiter used when both keys travel in the LLM provider's API key header:
# "<LLM Provider API Key>;invariant-auth=<Invariant API Key>".
API_KEYS_SEPARATOR = ";invariant-auth="
|
||||
|
||||
|
||||
def extract_authorization_from_headers(
    request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str
) -> Tuple[Optional[str], Optional[str]]:
    """
    Extracts the Invariant authorization and LLM Provider API key from the request headers.

    The Invariant API Key is normally sent in its own header:
        "invariant-authorization": "Bearer <Invariant API Key>"
    and the LLM Provider API Key in the provider specific header:
        {llm_provider_api_key_header}: "<API Key>"

    Some clients cannot send a custom header. For those, both keys may be
    packed into the provider header:
        {llm_provider_api_key_header}: "<API Key>;invariant-auth=<Invariant API Key>"

    The combined form is only parsed when a push to Explorer is requested
    (dataset_name is truthy) and the invariant-authorization header is absent.
    In every other case the two header values are returned exactly as received.

    Args:
        request: Incoming request whose headers are inspected.
        dataset_name: Explorer dataset to push to; falsy means no push is requested.
        llm_provider_api_key_header: Name of the header holding the LLM Provider API Key.

    Returns:
        A tuple (invariant_authorization, llm_provider_api_key). Either element
        is None when the corresponding header is missing and no push was requested
        (the provider key may also be None when the invariant-authorization
        header is present).

    Raises:
        HTTPException: status 400 when a push is requested without the
            invariant-authorization header and the provider header is missing,
            does not contain the separator, or is malformed.
    """
    invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER)
    llm_provider_api_key = request.headers.get(llm_provider_api_key_header)

    # The combined header only needs to be parsed when a push to Explorer is
    # requested and the dedicated Invariant header was not supplied.
    if not dataset_name or invariant_authorization is not None:
        return invariant_authorization, llm_provider_api_key

    if llm_provider_api_key is None:
        raise HTTPException(status_code=400, detail="Missing LLM Provider API Key")

    if API_KEYS_SEPARATOR not in llm_provider_api_key:
        raise HTTPException(status_code=400, detail="Missing invariant api key")

    # Both the API keys are passed in the llm_provider_api_key_header.
    # Exactly one separator is allowed, and the Invariant part must be non-empty.
    api_keys = llm_provider_api_key.split(API_KEYS_SEPARATOR)
    if len(api_keys) != 2 or not api_keys[1].strip():
        raise HTTPException(status_code=400, detail="Invalid API Key format")

    return f"Bearer {api_keys[1].strip()}", api_keys[0].strip()
|
||||
@@ -13,4 +13,3 @@ IGNORED_HEADERS = [
|
||||
]
|
||||
|
||||
CLIENT_TIMEOUT = 60.0
|
||||
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
|
||||
@@ -7,12 +7,12 @@ import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
from utils.constants import (
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
INVARIANT_AUTHORIZATION_HEADER,
|
||||
)
|
||||
from utils.explorer import push_trace
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -56,42 +56,14 @@ async def anthropic_v1_messages_gateway(
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
|
||||
# In case the user wants to push to Explorer, the request must contain the Invariant API Key
|
||||
# The invariant-authorization header contains the Invariant API Key
|
||||
# "invariant-authorization": "Bearer <Invariant API Key>"
|
||||
# The x-api-key header contains the Anthropic API Key
|
||||
# "x-api-key": "<Anthropic API Key>"
|
||||
#
|
||||
# For some clients, it is not possible to pass a custom header
|
||||
# In such cases, the Invariant API Key is passed as part of the
|
||||
# x-api-key header with the Anthropic API key.
|
||||
# The header in that case becomes:
|
||||
# "x-api-key": "<Anthropic API Key>;invariant-auth=<Invariant API Key>"
|
||||
invariant_authorization = None
|
||||
if dataset_name:
|
||||
if request.headers.get(
|
||||
INVARIANT_AUTHORIZATION_HEADER
|
||||
) is None and ";invariant-auth=" not in request.headers.get(
|
||||
ANTHROPIC_AUTHORIZATION_HEADER
|
||||
):
|
||||
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY)
|
||||
if request.headers.get(INVARIANT_AUTHORIZATION_HEADER):
|
||||
invariant_authorization = request.headers.get(
|
||||
INVARIANT_AUTHORIZATION_HEADER
|
||||
)
|
||||
else:
|
||||
header_value = request.headers.get(ANTHROPIC_AUTHORIZATION_HEADER)
|
||||
api_keys = header_value.split(";invariant-auth=")
|
||||
invariant_authorization = f"Bearer {api_keys[1].strip()}"
|
||||
# Update the authorization header to pass the Anthropic API Key
|
||||
headers[ANTHROPIC_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}"
|
||||
invariant_authorization, anthopic_api_key = extract_authorization_from_headers(
|
||||
request, dataset_name, ANTHROPIC_AUTHORIZATION_HEADER
|
||||
)
|
||||
headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key
|
||||
|
||||
request_body = await request.body()
|
||||
|
||||
request_body_json = json.loads(request_body)
|
||||
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
|
||||
anthropic_request = client.build_request(
|
||||
"POST",
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
|
||||
@@ -7,10 +7,17 @@ import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
)
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from utils.explorer import push_trace
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
GEMINI_AUTHORIZATION_HEADER = "x-goog-api-key"
|
||||
|
||||
|
||||
@gateway.post("/gemini/{api_version}/models/{model}:{endpoint}")
|
||||
async def gemini_generate_content_gateway(
|
||||
@@ -37,7 +44,15 @@ async def gemini_generate_content_gateway(
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
|
||||
invariant_authorization, gemini_api_key = extract_authorization_from_headers(
|
||||
request, dataset_name, GEMINI_AUTHORIZATION_HEADER
|
||||
)
|
||||
headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key
|
||||
|
||||
request_body_bytes = await request.body()
|
||||
request_body_json = json.loads(request_body_bytes)
|
||||
print("Here is the request: ", request_body_json)
|
||||
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
|
||||
if alt == "sse":
|
||||
@@ -56,7 +71,9 @@ async def gemini_generate_content_gateway(
|
||||
dataset_name,
|
||||
)
|
||||
response = await client.send(gemini_request)
|
||||
return await handle_non_streaming_response(response, dataset_name)
|
||||
return await handle_non_streaming_response(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
|
||||
async def stream_response(
|
||||
@@ -83,6 +100,7 @@ async def stream_response(
|
||||
continue
|
||||
|
||||
# Yield chunk immediately to the client
|
||||
print("Here is the response chunk: ", chunk)
|
||||
yield chunk
|
||||
|
||||
# Send full merged response to the explorer
|
||||
@@ -93,13 +111,33 @@ async def stream_response(
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Pushes the full trace to the Invariant Explorer.

    The trace is the list under request_body["messages"] followed by the
    "message" entry of every item under merged_response["choices"]. Missing
    keys are treated as empty lists.

    Args:
        dataset_name: Name of the Explorer dataset to push the trace to.
        merged_response: Parsed response body from the provider.
        request_body: Parsed request body sent by the client; it is not modified.
        invariant_authorization: Authorization value forwarded to push_trace.
    """
    # NOTE(review): "messages"/"choices" are OpenAI-style keys; Gemini's
    # generateContent API appears to use "contents"/"candidates" -- confirm that
    # the trace is not empty for Gemini payloads.
    # Copy the request messages so that appending the response messages does
    # not mutate the caller's request body in place.
    messages = list(request_body.get("messages", []))
    messages.extend(
        choice["message"] for choice in merged_response.get("choices", [])
    )
    # push_trace receives a list of traces; this call sends a single trace.
    await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
response: httpx.Response,
|
||||
dataset_name: Optional[str],
|
||||
request_body_json: dict[str, Any],
|
||||
invariant_authorization: Optional[str],
|
||||
) -> Response:
|
||||
"""Handles non-streaming Gemini responses"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
print("Here is the response: ", json_response)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
@@ -111,8 +149,9 @@ async def handle_non_streaming_response(
|
||||
detail=json_response.get("error", "Unknown error from Gemini API"),
|
||||
)
|
||||
if dataset_name:
|
||||
# Push to Explorer
|
||||
pass
|
||||
await push_to_explorer(
|
||||
dataset_name, json_response, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(json_response),
|
||||
|
||||
@@ -7,12 +7,12 @@ import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from utils.constants import (
|
||||
from common.constants import (
|
||||
CLIENT_TIMEOUT,
|
||||
IGNORED_HEADERS,
|
||||
INVARIANT_AUTHORIZATION_HEADER,
|
||||
)
|
||||
from utils.explorer import push_trace
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
|
||||
gateway = APIRouter()
|
||||
|
||||
@@ -47,43 +47,16 @@ async def openai_chat_completions_gateway(
|
||||
}
|
||||
headers["accept-encoding"] = "identity"
|
||||
|
||||
invariant_authorization, openai_api_key = extract_authorization_from_headers(
|
||||
request, dataset_name, OPENAI_AUTHORIZATION_HEADER
|
||||
)
|
||||
headers[OPENAI_AUTHORIZATION_HEADER] = openai_api_key
|
||||
|
||||
request_body_bytes = await request.body()
|
||||
request_body_json = json.loads(request_body_bytes)
|
||||
|
||||
# Check if the request is for streaming
|
||||
is_streaming = request_body_json.get("stream", False)
|
||||
|
||||
# In case the user wants to push to Explorer, the request must contain the Invariant API Key
|
||||
# The invariant-authorization header contains the Invariant API Key
|
||||
# "invariant-authorization": "Bearer <Invariant API Key>"
|
||||
# The authorization header contains the OpenAI API Key
|
||||
# "authorization": "<OpenAI API Key>"
|
||||
#
|
||||
# For some clients, it is not possible to pass a custom header
|
||||
# In such cases, the Invariant API Key is passed as part of the
|
||||
# authorization header with the OpenAI API key.
|
||||
# The header in that case becomes:
|
||||
# "authorization": "<OpenAI API Key>;invariant-auth=<Invariant API Key>"
|
||||
invariant_authorization = None
|
||||
if dataset_name:
|
||||
if request.headers.get(
|
||||
INVARIANT_AUTHORIZATION_HEADER
|
||||
) is None and ";invariant-auth=" not in request.headers.get(
|
||||
OPENAI_AUTHORIZATION_HEADER
|
||||
):
|
||||
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY)
|
||||
|
||||
if request.headers.get(INVARIANT_AUTHORIZATION_HEADER):
|
||||
invariant_authorization = request.headers.get(
|
||||
INVARIANT_AUTHORIZATION_HEADER
|
||||
)
|
||||
else:
|
||||
header_value = request.headers.get(OPENAI_AUTHORIZATION_HEADER)
|
||||
api_keys = header_value.split(";invariant-auth=")
|
||||
invariant_authorization = f"Bearer {api_keys[1].strip()}"
|
||||
# Update the authorization header to pass the OpenAI API Key to the OpenAI API
|
||||
headers[OPENAI_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}"
|
||||
|
||||
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
|
||||
open_ai_request = client.build_request(
|
||||
"POST",
|
||||
|
||||
Reference in New Issue
Block a user