mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-26 08:27:47 +02:00
613 lines
22 KiB
Python
613 lines
22 KiB
Python
"""Gateway service to forward requests to the Gemini APIs"""
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import Any, Literal, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from gateway.common.authorization import extract_authorization_from_headers
|
|
from gateway.common.config_manager import (
|
|
GatewayConfig,
|
|
GatewayConfigManager,
|
|
extract_guardrails_from_header,
|
|
)
|
|
from gateway.common.constants import (
|
|
CLIENT_TIMEOUT,
|
|
IGNORED_HEADERS,
|
|
)
|
|
from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
|
|
from gateway.common.request_context import RequestContext
|
|
from gateway.converters.gemini_to_invariant import convert_request, convert_response
|
|
from gateway.integrations.explorer import (
|
|
create_annotations_from_guardrails_errors,
|
|
fetch_guardrails_from_explorer,
|
|
push_trace,
|
|
)
|
|
from gateway.integrations.guardrails import (
|
|
ExtraItem,
|
|
InstrumentedResponse,
|
|
InstrumentedStreamingResponse,
|
|
Replacement,
|
|
check_guardrails,
|
|
)
|
|
|
|
# Router on which the Gemini proxy endpoints defined in this module are registered.
gateway = APIRouter()

# Header under which the Gemini API key is placed on the forwarded request.
GEMINI_AUTHORIZATION_HEADER = "x-goog-api-key"
# Fallback header handed to extract_authorization_from_headers; it is stripped
# from the headers that get forwarded upstream.
GEMINI_AUTHORIZATION_FALLBACK_HEADER = "authorization"
|
|
|
|
|
|
@gateway.post("/gemini/{api_version}/models/{model}:{endpoint}")
@gateway.post("/{dataset_name}/gemini/{api_version}/models/{model}:{endpoint}")
async def gemini_generate_content_gateway(
    request: Request,
    api_version: str,
    model: str,
    endpoint: str,
    dataset_name: str = None,  # This is None if the client doesn't want to push to Explorer
    alt: str = Query(
        None, title="Response Format", description="Set to 'sse' for streaming"
    ),
    config: GatewayConfig = Depends(GatewayConfigManager.get_config),  # pylint: disable=unused-argument
    header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
    """Proxy calls to the Gemini GenerateContent API.

    Args:
        request: Incoming client request; its headers and body are forwarded.
        api_version: Gemini API version segment of the upstream URL.
        model: Model name segment of the upstream URL.
        endpoint: Either ``generateContent`` or ``streamGenerateContent``.
        dataset_name: Explorer dataset to push the trace to, if any.
        alt: When ``"sse"``, ``?alt=sse`` is appended upstream and the
            response is streamed.
        config: Gateway configuration, passed on to the request context.
        header_guardrails: Guardrails taken from the request headers; they
            take precedence over guardrails fetched for the dataset.

    Returns:
        A 400 ``Response`` for unsupported endpoints, otherwise whatever the
        streaming / non-streaming handler returns.

    Raises:
        HTTPException: 400 if the request body is not valid JSON.
    """
    if endpoint not in ("generateContent", "streamGenerateContent"):
        # Adjacent literals instead of backslash continuations, so that source
        # indentation does not leak into the message.
        return Response(
            content=(
                "Invalid endpoint - the only endpoints supported are: "
                "/api/v1/gateway/gemini/<version>/models/<model-name>:generateContent or "
                "/api/v1/gateway/<dataset-name>/gemini/<version>/models/<model-name>:generateContent"
            ),
            status_code=400,
        )

    # Build the exclusion list once rather than once per header.
    ignored_headers = IGNORED_HEADERS + [GEMINI_AUTHORIZATION_FALLBACK_HEADER]
    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in ignored_headers
    }
    headers["accept-encoding"] = "identity"

    invariant_authorization, gemini_api_key = extract_authorization_from_headers(
        request,
        dataset_name,
        GEMINI_AUTHORIZATION_HEADER,
        [GEMINI_AUTHORIZATION_FALLBACK_HEADER],
    )
    headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key

    request_body_bytes = await request.body()
    try:
        request_json = json.loads(request_body_bytes)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A malformed body is a client error, not an internal server error.
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from exc

    client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT))
    gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}"
    if alt == "sse":
        gemini_api_url += "?alt=sse"
    gemini_request = client.build_request(
        "POST",
        gemini_api_url,
        content=request_body_bytes,
        headers=headers,
    )

    dataset_guardrails = None
    if dataset_name:
        # Get the guardrails for the dataset
        dataset_guardrails = await fetch_guardrails_from_explorer(
            dataset_name, invariant_authorization
        )

    context = RequestContext.create(
        request_json=request_json,
        dataset_name=dataset_name,
        invariant_authorization=invariant_authorization,
        guardrails=header_guardrails or dataset_guardrails,
        config=config,
        request=request,
    )

    if alt == "sse" or endpoint == "streamGenerateContent":
        return await stream_response(
            context,
            client,
            gemini_request,
        )
    return await handle_non_streaming_response(
        context,
        client,
        gemini_request,
    )
|
|
|
|
|
|
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
    """Instrumented streaming response for Gemini API.

    Forwards the upstream byte stream while accumulating a merged response,
    runs blocking guardrails on the request (before streaming) and on the
    merged response (once a finish reason has been seen), and pushes the
    trace to Explorer when a dataset name is set.
    """

    def __init__(
        self,
        context: RequestContext,
        client: httpx.AsyncClient,
        gemini_request: httpx.Request,
    ):
        super().__init__()

        # request data
        self.context: RequestContext = context
        self.client: httpx.AsyncClient = client
        self.gemini_request: httpx.Request = gemini_request

        # Store the progressively merged response
        self.merged_response = {
            "candidates": [{"content": {"parts": []}, "finishReason": None}]
        }

        # guardrailing execution result (if any)
        self.guardrails_execution_result: Optional[dict[str, Any]] = None

    def make_refusal(
        self,
        location: Literal["request", "response"],
        guardrails_execution_result: dict[str, Any],
    ) -> dict:
        """Create a refusal response for the given request or response.

        Args:
            location: Which side failed the guardrails; interpolated into the
                refusal messages.
            guardrails_execution_result: Guardrails result embedded in the
                ``error.details`` field and in the block reason message.

        Returns:
            A dict with ``candidates``, ``error`` and ``promptFeedback`` keys.
        """
        return {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "text": f"[Invariant] The {location} did not pass the guardrails",
                            }
                        ],
                    }
                }
            ],
            "error": {
                "code": 400,
                "message": f"[Invariant] The {location} did not pass the guardrails",
                "details": guardrails_execution_result,
                "status": "INVARIANT_GUARDRAILS_VIOLATION",
            },
            "promptFeedback": {
                "blockReason": "SAFETY",
                "block_reason_message": f"[Invariant] The {location} did not pass the guardrails: "
                + json.dumps(guardrails_execution_result),
                "safetyRatings": [
                    {
                        "category": "HARM_CATEGORY_UNSPECIFIED",
                        "probability": "HIGH",
                        "blocked": True,
                    }
                ],
            },
        }

    async def on_start(self):
        """
        Check guardrails in a pipelined fashion, before processing the first chunk
        (for input guardrailing).

        Returns an ``ExtraItem`` carrying a refusal chunk with
        ``end_of_stream=True`` when the request violates a blocking guardrail;
        otherwise returns ``None``.
        """
        if self.context.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, action=GuardrailAction.BLOCK, response_json={}
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    self.make_refusal("request", self.guardrails_execution_result)
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we end the stream prematurely (end_of_stream=True)
                # and yield an error chunk instead of actually beginning the stream
                return ExtraItem(
                    f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True
                )

    async def event_generator(self):
        """Event generator for streaming responses.

        Sends the prepared request upstream and yields the raw response bytes.

        Raises:
            HTTPException: with the upstream status code when Gemini does not
                answer with 200.
        """
        response = await self.client.send(self.gemini_request, stream=True)

        # A response obtained with stream=True must be closed explicitly,
        # otherwise its connection is never released. The finally also covers
        # the error path and early termination of this generator.
        # NOTE(review): self.client itself is not closed here - confirm who
        # owns its lifetime on the streaming path.
        try:
            if response.status_code != 200:
                error_content = await response.aread()
                try:
                    error_json = json.loads(error_content.decode("utf-8"))
                    error_detail = error_json.get(
                        "error", "Unknown error from Gemini API"
                    )
                except (json.JSONDecodeError, UnicodeDecodeError):
                    error_detail = {"error": "Failed to parse Gemini error response"}
                raise HTTPException(
                    status_code=response.status_code, detail=error_detail
                )

            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            await response.aclose()

    async def on_chunk(self, chunk):
        """Processes each chunk of the streaming response.

        Merges the chunk into ``merged_response``. Once the merged response
        carries a finish reason and guardrails are configured, runs the
        blocking guardrails and, on violation, returns an ``ExtraItem`` with a
        refusal chunk that ends the stream.
        """
        # NOTE(review): decodes each chunk on its own; presumably chunks never
        # split a multi-byte character or an SSE event - confirm.
        chunk_text = chunk.decode().strip()
        if not chunk_text:
            return

        # Parse and update merged_response incrementally
        process_chunk_text(self.merged_response, chunk_text)

        # runs on the last stream item
        if (
            self.merged_response.get("candidates", [])
            and self.merged_response.get("candidates")[0].get("finishReason", "")
            and self.context.guardrails
        ):
            # Block on the guardrails check
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context,
                action=GuardrailAction.BLOCK,
                response_json=self.merged_response,
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    self.make_refusal("response", self.guardrails_execution_result)
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.merged_response,
                            self.guardrails_execution_result,
                        )
                    )

                return ExtraItem(
                    value=f"data: {error_chunk}\r\n\r\n".encode(),
                    # for Gemini we have to end the stream prematurely, as the client SDK
                    # will not stop streaming when it encounters an error
                    end_of_stream=True,
                )

    async def on_end(self):
        """Runs when the stream ends."""

        # Push annotated trace to the explorer - don't block on its response
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context,
                    self.merged_response,
                    self.guardrails_execution_result,
                )
            )
|
|
|
|
|
|
async def stream_response(
    context: RequestContext,
    client: httpx.AsyncClient,
    gemini_request: httpx.Request,
) -> Response:
    """Stream the Gemini response back to the client as server-sent events.

    Wraps the prepared upstream request in an
    ``InstrumentedStreamingGeminiResponse`` and relays every item produced by
    its instrumented event generator.
    """
    instrumented = InstrumentedStreamingGeminiResponse(
        context=context, client=client, gemini_request=gemini_request
    )

    async def relay_items():
        # Pass through whatever the instrumented generator produces, unchanged.
        async for item in instrumented.instrumented_event_generator():
            yield item

    return StreamingResponse(relay_items(), media_type="text/event-stream")
|
|
|
|
|
|
def process_chunk_text(
    merged_response: dict[str, Any],
    chunk_text: str,
) -> None:
    """Processes the chunk text and updates the merged_response to be sent to the explorer.

    Args:
        merged_response: Accumulated response, updated in place.
        chunk_text: Decoded text of one stream chunk; it may contain several
            ``data: `` sections.

    Sections that are empty, are not valid JSON, or do not decode to a JSON
    object are skipped (a warning is printed for the latter two) and leave
    ``merged_response`` untouched.
    """
    # Split the chunk text into individual JSON strings
    # A single chunk can contain multiple "data: " sections
    for json_string in chunk_text.split("data: "):
        json_string = json_string.strip()

        if not json_string:
            continue

        try:
            json_chunk = json.loads(json_string)
        except json.JSONDecodeError:
            print("Warning: Could not parse chunk:", json_string)
            # Skip this section: without this, the merge below would run with
            # an unbound or stale json_chunk.
            continue

        if not isinstance(json_chunk, dict):
            # The merge step reads the payload with dict accessors.
            print("Warning: Could not parse chunk:", json_string)
            continue

        update_merged_response(merged_response, json_chunk)
|
|
|
|
|
|
def update_merged_response(merged_response: dict[str, Any], chunk_json: dict) -> None:
    """Fold one streamed chunk into the accumulated response, in place.

    Text parts are concatenated onto the last merged part when that part is
    text too, otherwise appended as a new part; function-call parts are always
    appended. Role and finish reason of every candidate in the chunk are
    written to the first merged candidate. ``usageMetadata`` and
    ``modelVersion`` are copied to the top level when present.
    """
    for candidate in chunk_json.get("candidates", []):
        content = candidate.get("content", {})

        for part in content.get("parts", []):
            if "text" in part:
                merged_parts = merged_response["candidates"][0]["content"]["parts"]
                extends_last_text = bool(merged_parts) and "text" in merged_parts[-1]
                if extends_last_text:
                    merged_parts[-1]["text"] += part["text"]
                else:
                    merged_parts.append({"text": part["text"]})

            if "functionCall" in part:
                call_entry = {"functionCall": part["functionCall"]}
                merged_response["candidates"][0]["content"]["parts"].append(call_entry)

        if "role" in content:
            merged_response["candidates"][0]["content"]["role"] = content["role"]

        if "finishReason" in candidate:
            merged_response["candidates"][0]["finishReason"] = candidate["finishReason"]

    # Top-level fields are simply overwritten by the latest chunk carrying them.
    for top_level_key in ("usageMetadata", "modelVersion"):
        if top_level_key in chunk_json:
            merged_response[top_level_key] = chunk_json[top_level_key]
|
|
|
|
|
|
def create_metadata(
    context: RequestContext, response_json: dict[str, Any]
) -> dict[str, Any]:
    """Build the metadata dict attached to a trace.

    Contains every request field except ``systemInstruction`` and
    ``contents``, a ``via_gateway`` flag set to ``True``, and the
    ``usageMetadata`` / ``modelVersion`` fields of the response when present.
    """
    excluded_request_keys = ("systemInstruction", "contents")
    copied_response_keys = ("usageMetadata", "modelVersion")

    request_fields = {}
    for name, value in context.request_json.items():
        if name not in excluded_request_keys:
            request_fields[name] = value

    response_fields = {}
    for name, value in response_json.items():
        if name in copied_response_keys:
            response_fields[name] = value

    # Later entries win, so the flag and the response fields override any
    # request field that happens to share their name.
    return {**request_fields, "via_gateway": True, **response_fields}
|
|
|
|
|
|
async def get_guardrails_check_result(
    context: RequestContext, action: GuardrailAction, response_json: dict[str, Any]
) -> dict[str, Any]:
    """Get the guardrails check result.

    Args:
        context: Request context holding the request JSON and the guardrail
            rule set (if any).
        action: ``GuardrailAction.LOG`` selects the logging guardrails; any
            other action selects the blocking guardrails.
        response_json: Gemini response to check together with the request;
            callers pass ``{}`` to check the request alone.

    Returns:
        The result of ``check_guardrails``, or ``{}`` when the context has no
        rule set or the selected guardrails are empty.
    """
    rule_set = context.guardrails
    if rule_set is None:
        # Not every caller checks for a rule set first (push_to_explorer calls
        # this unconditionally), so a missing one must not raise here.
        return {}

    # Determine which guardrails to apply based on the action
    guardrails = (
        rule_set.logging_guardrails
        if action == GuardrailAction.LOG
        else rule_set.blocking_guardrails
    )
    if not guardrails:
        return {}

    converted_requests = convert_request(context.request_json)
    converted_responses = convert_response(response_json)

    # Block on the guardrails check
    guardrails_execution_result = await check_guardrails(
        messages=converted_requests + converted_responses,
        guardrails=guardrails,
        context=context,
    )
    return guardrails_execution_result
|
|
|
|
|
|
async def push_to_explorer(
    context: RequestContext,
    response_json: dict[str, Any],
    guardrails_execution_result: Optional[dict] = None,
) -> None:
    """Push the full trace for one request to the Invariant Explorer.

    Annotations are built from the errors of the supplied guardrails result
    and from the logging guardrails, which are executed here before the push.
    """
    blocking_result = guardrails_execution_result or {}
    annotations = create_annotations_from_guardrails_errors(
        blocking_result.get("errors", [])
    )

    # The logging guardrails run at push time; their findings are appended to
    # the annotations derived from the result passed in by the caller.
    logging_result = await get_guardrails_check_result(
        context,
        action=GuardrailAction.LOG,
        response_json=response_json,
    )
    annotations.extend(
        create_annotations_from_guardrails_errors(logging_result.get("errors", []))
    )

    trace_messages = convert_request(context.request_json) + convert_response(
        response_json
    )

    await push_trace(
        dataset_name=context.dataset_name,
        messages=[trace_messages],
        invariant_authorization=context.invariant_authorization,
        metadata=[create_metadata(context, response_json)],
        annotations=[annotations] if annotations else None,
    )
|
|
|
|
|
|
class InstrumentedGeminiResponse(InstrumentedResponse):
    """Instrumented response for Gemini API.

    Runs blocking guardrails on the request before it is sent and on the
    response after it is received, replaces the response with a 400 refusal
    on violation, and pushes the trace to Explorer when a dataset name is set.
    """

    def __init__(
        self,
        context: RequestContext,
        client: httpx.AsyncClient,
        gemini_request: httpx.Request,
    ):
        super().__init__()

        # request data
        self.context: RequestContext = context
        self.client: httpx.AsyncClient = client
        self.gemini_request: httpx.Request = gemini_request

        # response data
        self.response: Optional[httpx.Response] = None
        self.response_json: Optional[dict[str, Any]] = None

        # guardrails execution result (if any)
        self.guardrails_execution_result: Optional[dict[str, Any]] = None

    async def on_start(self):
        """
        Check the blocking guardrails on the request (input guardrailing).

        Returns a ``Replacement`` wrapping a 400 JSON response when the
        request violates a blocking guardrail; otherwise returns ``None``.
        """
        if self.context.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, action=GuardrailAction.BLOCK, response_json={}
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "code": 400,
                            "message": "[Invariant] The request did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
                        },
                        "prompt_feedback": {
                            "blockReason": "SAFETY",
                            "safetyRatings": [
                                {
                                    "category": "HARM_CATEGORY_UNSPECIFIED",
                                    "probability": 0.0,
                                    "blocked": True,
                                }
                            ],
                        },
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we answer with the refusal instead of
                # forwarding the request
                return Replacement(
                    Response(
                        content=error_chunk,
                        status_code=400,
                        media_type="application/json",
                        headers={
                            "Content-Type": "application/json",
                        },
                    )
                )

    async def request(self):
        """Makes the request to the Gemini API and return the response.

        Raises:
            HTTPException: with the upstream status code when the body is not
                valid JSON or the status is not 200.
        """
        self.response = await self.client.send(self.gemini_request)

        response_string = self.response.text
        response_code = self.response.status_code

        try:
            self.response_json = self.response.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=self.response.status_code,
                detail="Invalid JSON response received from Gemini API",
            ) from e
        if self.response.status_code != 200:
            raise HTTPException(
                status_code=self.response.status_code,
                detail=self.response_json.get("error", "Unknown error from Gemini API"),
            )

        return Response(
            content=response_string,
            status_code=response_code,
            media_type="application/json",
            headers=dict(self.response.headers),
        )

    async def on_end(self):
        """Runs when the request ends.

        Checks the blocking guardrails on the response. On violation returns
        a ``Replacement`` wrapping a 400 refusal; otherwise returns ``None``.
        In both cases the trace is pushed to Explorer when a dataset name is
        set.
        """
        # Bound up front: it is read below even when no guardrails are
        # configured and the check is skipped.
        guardrails_execution_result: dict[str, Any] = {}

        if self.context.guardrails:
            # Block on the guardrails check
            guardrails_execution_result = await get_guardrails_check_result(
                self.context,
                action=GuardrailAction.BLOCK,
                response_json=self.response_json,
            )
            if guardrails_execution_result.get("errors", []):
                response_string = json.dumps(
                    {
                        "error": {
                            "code": 400,
                            "message": "[Invariant] The response did not pass the guardrails",
                            "details": guardrails_execution_result,
                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
                        },
                    }
                )

                if self.context.dataset_name:
                    # Push to Explorer - don't block on its response
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.response_json,
                            guardrails_execution_result,
                        )
                    )

                # The upstream content-length describes the original body, not
                # the refusal; drop it so the framework computes the right one.
                replacement_headers = {
                    name: value
                    for name, value in self.response.headers.items()
                    if name.lower() != "content-length"
                }
                return Replacement(
                    Response(
                        content=response_string,
                        status_code=400,
                        media_type="application/json",
                        headers=replacement_headers,
                    )
                )

        # Otherwise, also push to Explorer - don't block on its response
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context, self.response_json, guardrails_execution_result
                )
            )
|
|
|
|
|
|
async def handle_non_streaming_response(
    context: RequestContext,
    client: httpx.AsyncClient,
    gemini_request: httpx.Request,
) -> Response:
    """Handles non-streaming Gemini responses.

    Args:
        context: Request context (request JSON, dataset name, guardrails).
        client: HTTP client used to send ``gemini_request``. It is closed
            before this function returns, so it must not be reused afterwards.
        gemini_request: Prepared upstream request.

    Returns:
        The result of the instrumented request.
    """
    response = InstrumentedGeminiResponse(
        context=context,
        client=client,
        gemini_request=gemini_request,
    )

    try:
        return await response.instrumented_request()
    finally:
        # The client is created per incoming request and the upstream body has
        # been read in full by now; close it so its connections are released.
        await client.aclose()
|