Add push-to-Explorer support for non-streaming Gemini responses.

This commit is contained in:
Hemang
2025-03-07 09:11:35 +01:00
committed by Hemang Sarkar
parent ef4f7f146b
commit e107be4fea
5 changed files with 108 additions and 73 deletions
+52
View File
@@ -0,0 +1,52 @@
"""Common Authorization functions used in the gateway."""
from typing import Tuple, Optional
from fastapi import HTTPException, Request
# Dedicated request header that carries the Invariant API Key as
# "Bearer <Invariant API Key>".
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
# Marker that separates the LLM Provider API Key from an Invariant API Key
# appended to the same header value:
# "<API Key>;invariant-auth=<Invariant API Key>".
API_KEYS_SEPARATOR = ";invariant-auth="
def extract_authorization_from_headers(
    request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str
) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract the Invariant authorization and the LLM Provider API key from the request headers.

    The Invariant API Key is normally sent in its own header:
        "invariant-authorization": "Bearer <Invariant API Key>"
    and the LLM Provider API Key in the provider header:
        {llm_provider_api_key_header}: "<API Key>"

    Some clients cannot send a custom header. They append the Invariant API
    Key to the provider header instead:
        {llm_provider_api_key_header}: "<API Key>;invariant-auth=<Invariant API Key>"

    The combined form is parsed only when the caller wants to push to Explorer
    (dataset_name is truthy) and no invariant-authorization header was sent.

    Args:
        request: Incoming request whose headers are inspected.
        dataset_name: Explorer dataset to push to; falsy means no push is
            requested and no Invariant key is required.
        llm_provider_api_key_header: Name of the header that carries the LLM
            Provider API key.

    Returns:
        A tuple (invariant_authorization, llm_provider_api_key). Either
        element is None when the corresponding header is absent and was not
        required.

    Raises:
        HTTPException: 400 when a push is requested without the dedicated
            header and the provider header is missing, lacks the
            ";invariant-auth=" part, or does not split into exactly one
            non-empty provider key and one non-empty Invariant key.
    """
    invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER)
    llm_provider_api_key = request.headers.get(llm_provider_api_key_header)

    # Nothing to derive: either no push was requested, or the Invariant key
    # arrived in its own header.
    # NOTE(review): on this path the provider header value is returned as-is,
    # even if it contains ";invariant-auth=" - confirm that is intended.
    if not dataset_name or invariant_authorization is not None:
        return invariant_authorization, llm_provider_api_key

    if llm_provider_api_key is None:
        raise HTTPException(status_code=400, detail="Missing LLM Provider API Key")
    if API_KEYS_SEPARATOR not in llm_provider_api_key:
        raise HTTPException(status_code=400, detail="Missing invariant api key")

    # Both API keys are packed into the provider header.
    api_keys = [part.strip() for part in llm_provider_api_key.split(API_KEYS_SEPARATOR)]
    # Exactly one separator is allowed, and neither side may be blank.
    if len(api_keys) != 2 or not all(api_keys):
        raise HTTPException(status_code=400, detail="Invalid API Key format")

    provider_key, invariant_key = api_keys
    return f"Bearer {invariant_key}", provider_key
@@ -13,4 +13,3 @@ IGNORED_HEADERS = [
]
CLIENT_TIMEOUT = 60.0
INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization"
+6 -34
View File
@@ -7,12 +7,12 @@ import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.constants import (
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
INVARIANT_AUTHORIZATION_HEADER,
)
from utils.explorer import push_trace
from common.authorization import extract_authorization_from_headers
gateway = APIRouter()
@@ -56,42 +56,14 @@ async def anthropic_v1_messages_gateway(
}
headers["accept-encoding"] = "identity"
# In case the user wants to push to Explorer, the request must contain the Invariant API Key
# The invariant-authorization header contains the Invariant API Key
# "invariant-authorization": "Bearer <Invariant API Key>"
# The x-api-key header contains the Anthropic API Key
# "x-api-key": "<Anthropic API Key>"
#
# For some clients, it is not possible to pass a custom header
# In such cases, the Invariant API Key is passed as part of the
# x-api-key header with the Anthropic API key.
# The header in that case becomes:
# "x-api-key": "<Anthropic API Key>;invariant-auth=<Invariant API Key>"
invariant_authorization = None
if dataset_name:
if request.headers.get(
INVARIANT_AUTHORIZATION_HEADER
) is None and ";invariant-auth=" not in request.headers.get(
ANTHROPIC_AUTHORIZATION_HEADER
):
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY)
if request.headers.get(INVARIANT_AUTHORIZATION_HEADER):
invariant_authorization = request.headers.get(
INVARIANT_AUTHORIZATION_HEADER
)
else:
header_value = request.headers.get(ANTHROPIC_AUTHORIZATION_HEADER)
api_keys = header_value.split(";invariant-auth=")
invariant_authorization = f"Bearer {api_keys[1].strip()}"
# Update the authorization header to pass the Anthropic API Key
headers[ANTHROPIC_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}"
invariant_authorization, anthopic_api_key = extract_authorization_from_headers(
request, dataset_name, ANTHROPIC_AUTHORIZATION_HEADER
)
headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key
request_body = await request.body()
request_body_json = json.loads(request_body)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
anthropic_request = client.build_request(
"POST",
"https://api.anthropic.com/v1/messages",
+43 -4
View File
@@ -7,10 +7,17 @@ import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
from fastapi.responses import StreamingResponse
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
)
from common.authorization import extract_authorization_from_headers
from utils.explorer import push_trace
gateway = APIRouter()
# Header the Gemini API key is read from on the incoming request and set on
# the outgoing headers.
GEMINI_AUTHORIZATION_HEADER = "x-goog-api-key"
@gateway.post("/gemini/{api_version}/models/{model}:{endpoint}")
async def gemini_generate_content_gateway(
@@ -37,7 +44,15 @@ async def gemini_generate_content_gateway(
}
headers["accept-encoding"] = "identity"
invariant_authorization, gemini_api_key = extract_authorization_from_headers(
request, dataset_name, GEMINI_AUTHORIZATION_HEADER
)
headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
print("Here is the request: ", request_body_json)
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
if alt == "sse":
@@ -56,7 +71,9 @@ async def gemini_generate_content_gateway(
dataset_name,
)
response = await client.send(gemini_request)
return await handle_non_streaming_response(response, dataset_name)
return await handle_non_streaming_response(
response, dataset_name, request_body_json, invariant_authorization
)
async def stream_response(
@@ -83,6 +100,7 @@ async def stream_response(
continue
# Yield chunk immediately to the client
print("Here is the response chunk: ", chunk)
yield chunk
# Send full merged response to the explorer
@@ -93,13 +111,33 @@ async def stream_response(
return StreamingResponse(event_generator(), media_type="text/event-stream")
async def push_to_explorer(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Push the request messages plus the response messages to the Invariant Explorer.

    Args:
        dataset_name: Explorer dataset the trace is pushed to.
        merged_response: Response payload; the "message" of every entry under
            "choices" is appended to the trace.
        request_body: Request payload; its "messages" list starts the trace.
            It is not modified.
        invariant_authorization: Authorization value passed through to
            push_trace.
    """
    # Copy first: extending the list in place would also grow the "messages"
    # list held inside the caller's request_body.
    messages = list(request_body.get("messages") or [])
    # NOTE(review): "messages"/"choices" look like the OpenAI chat schema;
    # confirm Gemini payloads carry these keys, otherwise the pushed trace
    # is empty.
    messages.extend(
        choice["message"] for choice in merged_response.get("choices") or []
    )
    # push_trace receives a list of traces; this call sends exactly one.
    await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
async def handle_non_streaming_response(
response: httpx.Response,
dataset_name: Optional[str],
request_body_json: dict[str, Any],
invariant_authorization: Optional[str],
) -> Response:
"""Handles non-streaming Gemini responses"""
try:
json_response = response.json()
print("Here is the response: ", json_response)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
@@ -111,8 +149,9 @@ async def handle_non_streaming_response(
detail=json_response.get("error", "Unknown error from Gemini API"),
)
if dataset_name:
# Push to Explorer
pass
await push_to_explorer(
dataset_name, json_response, request_body_json, invariant_authorization
)
return Response(
content=json.dumps(json_response),
+7 -34
View File
@@ -7,12 +7,12 @@ import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from utils.constants import (
from common.constants import (
CLIENT_TIMEOUT,
IGNORED_HEADERS,
INVARIANT_AUTHORIZATION_HEADER,
)
from utils.explorer import push_trace
from common.authorization import extract_authorization_from_headers
gateway = APIRouter()
@@ -47,43 +47,16 @@ async def openai_chat_completions_gateway(
}
headers["accept-encoding"] = "identity"
invariant_authorization, openai_api_key = extract_authorization_from_headers(
request, dataset_name, OPENAI_AUTHORIZATION_HEADER
)
headers[OPENAI_AUTHORIZATION_HEADER] = openai_api_key
request_body_bytes = await request.body()
request_body_json = json.loads(request_body_bytes)
# Check if the request is for streaming
is_streaming = request_body_json.get("stream", False)
# In case the user wants to push to Explorer, the request must contain the Invariant API Key
# The invariant-authorization header contains the Invariant API Key
# "invariant-authorization": "Bearer <Invariant API Key>"
# The authorization header contains the OpenAI API Key
# "authorization": "<OpenAI API Key>"
#
# For some clients, it is not possible to pass a custom header
# In such cases, the Invariant API Key is passed as part of the
# authorization header with the OpenAI API key.
# The header in that case becomes:
# "authorization": "<OpenAI API Key>;invariant-auth=<Invariant API Key>"
invariant_authorization = None
if dataset_name:
if request.headers.get(
INVARIANT_AUTHORIZATION_HEADER
) is None and ";invariant-auth=" not in request.headers.get(
OPENAI_AUTHORIZATION_HEADER
):
raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY)
if request.headers.get(INVARIANT_AUTHORIZATION_HEADER):
invariant_authorization = request.headers.get(
INVARIANT_AUTHORIZATION_HEADER
)
else:
header_value = request.headers.get(OPENAI_AUTHORIZATION_HEADER)
api_keys = header_value.split(";invariant-auth=")
invariant_authorization = f"Bearer {api_keys[1].strip()}"
# Update the authorization header to pass the OpenAI API Key to the OpenAI API
headers[OPENAI_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}"
client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
open_ai_request = client.build_request(
"POST",