fix gemini streamed refusal

This commit is contained in:
Luca Beurer-Kellner
2025-03-28 22:45:18 +01:00
committed by Hemang
parent c2177faaa8
commit cd6c15105f
4 changed files with 471 additions and 124 deletions
+2 -5
View File
@@ -6,13 +6,8 @@ import time
from typing import Any, Dict, List
from functools import wraps
from fastapi.responses import StreamingResponse
import httpx
<<<<<<< HEAD:gateway/integrations/guardails.py
=======
from zmq import IO_THREADS
from common.request_context_data import RequestContextData
>>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
@@ -250,6 +245,8 @@ class InstrumentedStreamingResponse:
yield extra_item.value
# if end_of_stream is True, stop the stream
if extra_item.end_of_stream:
# cancel next task
next_item_task.cancel()
return
# yield item
+335 -114
View File
@@ -2,7 +2,7 @@
import asyncio
import json
from typing import Any, Optional
from typing import Any, Literal, Optional
import httpx
from common.config_manager import GatewayConfig, GatewayConfigManager
@@ -15,6 +15,14 @@ from common.constants import (
from common.authorization import extract_authorization_from_headers
from common.request_context_data import RequestContextData
from converters.gemini_to_invariant import convert_request, convert_response
from integrations.guardrails import (
ExtraItem,
InstrumentedResponse,
InstrumentedStreamingResponse,
Replacement,
preload_guardrails,
check_guardrails,
)
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
from integrations.guardrails import check_guardrails, preload_guardrails
@@ -82,13 +90,169 @@ async def gemini_generate_content_gateway(
client,
gemini_request,
)
response = await client.send(gemini_request)
return await handle_non_streaming_response(
context,
response,
client,
gemini_request,
)
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
# Store the progressively merged response
self.merged_response = {
"candidates": [{"content": {"parts": []}, "finishReason": None}]
}
# guardrailing execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
def make_refusal(
self,
location: Literal["request", "response"],
guardrails_execution_result: dict[str, Any],
) -> dict:
return {
"candidates": [
{
"content": {
"parts": [
{
"text": f"[Invariant] The {location} did not pass the guardrails",
}
],
}
}
],
"error": {
"code": 400,
"message": f"[Invariant] The {location} did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
"promptFeedback": {
"blockReason": "SAFETY",
"block_reason_message": f"[Invariant] The {location} did not pass the guardrails: "
+ json.dumps(guardrails_execution_result),
"safetyRatings": [
{
"category": "HARM_CATEGORY_UNSPECIFIED",
"probability": "HIGH",
"blocked": True,
}
],
},
}
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
self.make_refusal("request", self.guardrails_execution_result)
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return ExtraItem(
f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True
)
async def event_generator(self):
response = await self.client.send(self.gemini_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from Gemini API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse Gemini error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
async for chunk in response.aiter_bytes():
yield chunk
async def on_chunk(self, chunk):
chunk_text = chunk.decode().strip()
if not chunk_text:
return
# Parse and update merged_response incrementally
process_chunk_text(self.merged_response, chunk_text)
# runs on the last stream item
if (
self.merged_response.get("candidates", [])
and self.merged_response.get("candidates")[0].get("finishReason", "")
and self.context.config
and self.context.config.guardrails
):
# Block on the guardrails check
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, self.merged_response
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
self.make_refusal("response", self.guardrails_execution_result)
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
return ExtraItem(
value=f"data: {error_chunk}\r\n\r\n".encode(),
# for Gemini we have to end the stream prematurely, as the client SDK
# will not stop streaming when it encounters an error
end_of_stream=True,
)
async def on_end(self):
"""Runs when the stream ends."""
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
self.merged_response,
self.guardrails_execution_result,
)
)
async def stream_response(
context: RequestContextData,
client: httpx.AsyncClient,
@@ -96,76 +260,21 @@ async def stream_response(
) -> Response:
"""Handles streaming the Gemini response to the client"""
response = await client.send(gemini_request, stream=True)
if response.status_code != 200:
error_content = await response.aread()
try:
error_json = json.loads(error_content.decode("utf-8"))
error_detail = error_json.get("error", "Unknown error from Gemini API")
except json.JSONDecodeError:
error_detail = {"error": "Failed to parse Gemini error response"}
raise HTTPException(status_code=response.status_code, detail=error_detail)
response = InstrumentedStreamingGeminiResponse(
context=context,
client=client,
gemini_request=gemini_request,
)
async def event_generator() -> Any:
# Store the progressively merged response
merged_response = {
"candidates": [{"content": {"parts": []}, "finishReason": None}]
}
async for chunk in response.aiter_bytes():
chunk_text = chunk.decode().strip()
if not chunk_text:
continue
# Parse and update merged_response incrementally
process_chunk_text(merged_response, chunk_text)
if (
merged_response.get("candidates", [])
and merged_response.get("candidates")[0].get("finishReason", "")
and context.config
and context.config.guardrails
):
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, merged_response
)
if guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
# Push annotated trace to the explorer - don't block on its response
if context.dataset_name:
asyncio.create_task(
push_to_explorer(
context,
merged_response,
guardrails_execution_result,
)
)
yield f"data: {error_chunk}\n\n".encode()
return
# Yield chunk immediately to the client
async def event_generator():
async for chunk in response.instrumented_event_generator():
yield chunk
print("chunk", chunk)
if context.dataset_name:
# Push to Explorer - don't block on the response
asyncio.create_task(
push_to_explorer(
context,
merged_response,
)
)
return StreamingResponse(event_generator(), media_type="text/event-stream")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)
def process_chunk_text(
@@ -281,53 +390,165 @@ async def push_to_explorer(
)
class InstrumentedGeminiResponse(InstrumentedResponse):
def __init__(
self,
context: RequestContextData,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
):
super().__init__()
# request data
self.context: RequestContextData = context
self.client: httpx.AsyncClient = client
self.gemini_request: httpx.Request = gemini_request
# response data
self.response: Optional[httpx.Response] = None
self.response_json: Optional[dict[str, Any]] = None
# guardrails execution result (if any)
self.guardrails_execution_result: Optional[dict[str, Any]] = None
async def on_start(self):
"""Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing)."""
if self.context.config and self.context.config.guardrails:
self.guardrails_execution_result = await get_guardrails_check_result(
self.context, {}
)
if self.guardrails_execution_result.get("errors", []):
error_chunk = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
"prompt_feedback": {
"blockReason": "SAFETY",
"safetyRatings": [
{
"category": "HARM_CATEGORY_UNSPECIFIED",
"probability": 0.0,
"blocked": True,
}
],
},
}
)
# Push annotated trace to the explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context,
{},
self.guardrails_execution_result,
)
)
# if we find something, we end the stream prematurely (end_of_stream=True)
# and yield an error chunk instead of actually beginning the stream
return Replacement(
Response(
content=error_chunk,
status_code=400,
media_type="application/json",
headers={
"Content-Type": "application/json",
},
)
)
async def request(self):
self.response = await self.client.send(self.gemini_request)
response_string = self.response.text
response_code = self.response.status_code
try:
self.response_json = self.response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=self.response.status_code,
detail="Invalid JSON response received from Gemini API",
) from e
if self.response.status_code != 200:
raise HTTPException(
status_code=self.response.status_code,
detail=self.response_json.get("error", "Unknown error from Gemini API"),
)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
async def on_end(self):
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
if self.context.config and self.context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
self.context, self.response_json
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
response_code = 400
if self.context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(
self.context,
self.response_json,
guardrails_execution_result,
)
)
return Replacement(
Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(self.response.headers),
)
)
# Otherwise, also push to Explorer - don't block on its response
if self.context.dataset_name:
asyncio.create_task(
push_to_explorer(
self.context, self.response_json, guardrails_execution_result
)
)
async def handle_non_streaming_response(
context: RequestContextData,
response: httpx.Response,
client: httpx.AsyncClient,
gemini_request: httpx.Request,
) -> Response:
"""Handles non-streaming Gemini responses"""
try:
response_json = response.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=response.status_code,
detail="Invalid JSON response received from Gemini API",
) from e
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=response_json.get("error", "Unknown error from Gemini API"),
)
guardrails_execution_result = {}
response_string = json.dumps(response_json)
response_code = response.status_code
if context.config and context.config.guardrails:
# Block on the guardrails check
guardrails_execution_result = await get_guardrails_check_result(
context, response_json
)
if guardrails_execution_result.get("errors", []):
response_string = json.dumps(
{
"error": {
"code": 400,
"message": "[Invariant] The response did not pass the guardrails",
"details": guardrails_execution_result,
"status": "INVARIANT_GUARDRAILS_VIOLATION",
},
}
)
response_code = 400
if context.dataset_name:
# Push to Explorer - don't block on its response
asyncio.create_task(
push_to_explorer(context, response_json, guardrails_execution_result)
)
return Response(
content=response_string,
status_code=response_code,
media_type="application/json",
headers=dict(response.headers),
response = InstrumentedGeminiResponse(
context=context,
client=client,
gemini_request=gemini_request,
)
return await response.instrumented_request()
@@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file(
else:
response = client.models.generate_content_stream(**request)
for chunk in response:
assert "Dublin" not in str(chunk)
assert_is_streamed_refusal(
response,
[
"[Invariant] The response did not pass the guardrails",
"Dublin detected in the response",
],
)
if push_to_explorer:
# Wait for the trace to be saved
@@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file(
**request,
)
for chunk in response:
assert "Madrid" not in str(chunk)
assert_is_streamed_refusal(
response,
[
"[Invariant] The response did not pass the guardrails",
"get_capital is called with Germany as argument",
],
)
if push_to_explorer:
# Wait for the trace to be saved
@@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test input guardrail enforcement with Gemini."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = genai.Client(
api_key=os.getenv("GEMINI_API_KEY"),
http_options={
"headers": {
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
"base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/gemini",
},
)
request = {
"model": "gemini-2.0-flash",
"contents": "Tell me more about Fight Club.",
"config": {
"maxOutputTokens": 200,
},
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
client.models.generate_content(**request)
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
)
else:
response = client.models.generate_content_stream(**request)
assert_is_streamed_refusal(
response,
[
"[Invariant] The request did not pass the guardrails",
"Users must not mention the magic phrase 'Fight Club'",
],
)
if push_to_explorer:
time.sleep(2)
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 1
assert trace["messages"][0] == {
"role": "user",
"content": [{"type": "text", "text": "Tell me more about Fight Club."}],
}
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"]
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
def is_refusal(chunk):
return (
len(chunk.candidates) == 1
and chunk.candidates[0].content.parts[0].text.startswith("[Invariant]")
and chunk.prompt_feedback is not None
and "BlockedReason.SAFETY" in str(chunk.prompt_feedback)
)
def assert_is_streamed_refusal(response, expected_message_components: list[str]):
"""
Validates that the streamed response contains a refusal at the end (or as only message).
"""
num_chunks = 0
for c in response:
num_chunks += 1
assert num_chunks >= 1, "Expected at least one chunk"
# last chunk must be a refusal
assert is_refusal(c)
for emc in expected_message_components:
assert (
emc in c.model_dump_json()
), f"Expected message component {emc} not found in refusal message: {c.model_dump_json()}"
@@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file(
trace = trace_response.json()
# in case of input guardrailing, the pushed trace will not contain a response
assert len(trace["messages"]) == 1, "Trace should only contain the user message"
assert len(trace["messages"]) == 1
assert trace["messages"][0] == {
"role": "user",
"content": "Tell me more about Fight Club.",