From cd6c15105fdad345404e4ef2901aa4d94b110abc Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Fri, 28 Mar 2025 22:45:18 +0100 Subject: [PATCH] fix gemini streamed refusal --- gateway/integrations/guardrails.py | 7 +- gateway/routes/gemini.py | 449 +++++++++++++----- .../guardrails/test_guardrails_gemini.py | 137 +++++- .../guardrails/test_guardrails_open_ai.py | 2 +- 4 files changed, 471 insertions(+), 124 deletions(-) diff --git a/gateway/integrations/guardrails.py b/gateway/integrations/guardrails.py index 481884f..b0f3601 100644 --- a/gateway/integrations/guardrails.py +++ b/gateway/integrations/guardrails.py @@ -6,13 +6,8 @@ import time from typing import Any, Dict, List from functools import wraps -from fastapi.responses import StreamingResponse import httpx -<<<<<<< HEAD:gateway/integrations/guardails.py -======= -from zmq import IO_THREADS from common.request_context_data import RequestContextData ->>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -250,6 +245,8 @@ class InstrumentedStreamingResponse: yield extra_item.value # if end_of_stream is True, stop the stream if extra_item.end_of_stream: + # cancel next task + next_item_task.cancel() return # yield item diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 25f94e4..59d0874 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -2,7 +2,7 @@ import asyncio import json -from typing import Any, Optional +from typing import Any, Literal, Optional import httpx from common.config_manager import GatewayConfig, GatewayConfigManager @@ -15,6 +15,14 @@ from common.constants import ( from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData from converters.gemini_to_invariant import convert_request, convert_response +from integrations.guardrails import ( + ExtraItem, + InstrumentedResponse, + InstrumentedStreamingResponse, + Replacement, + preload_guardrails, + check_guardrails, +) from integrations.explorer import create_annotations_from_guardrails_errors, push_trace from integrations.guardrails import check_guardrails, preload_guardrails @@ -82,13 +90,169 @@ async def gemini_generate_content_gateway( client, gemini_request, ) - response = await client.send(gemini_request) return await handle_non_streaming_response( context, - response, + client, + gemini_request, ) +class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + gemini_request: httpx.Request, + ): + super().__init__() + + # request data + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.gemini_request: httpx.Request = gemini_request + + # Store the progressively merged response + self.merged_response = { + "candidates": [{"content": {"parts": []}, "finishReason": None}] + } + + # guardrailing execution result (if any) + self.guardrails_execution_result: Optional[dict[str, Any]] = None + + def make_refusal( + self, + location: Literal["request", "response"], + guardrails_execution_result: dict[str, Any], + ) -> dict: + return { + "candidates": [ + { + "content": { + "parts": [ + { + "text": f"[Invariant] The {location} did not pass the guardrails", + } + ], + } + } + ], + "error": { + "code": 400, + "message": f"[Invariant] The {location} did not pass the guardrails", + "details": guardrails_execution_result, + "status": "INVARIANT_GUARDRAILS_VIOLATION", + }, + "promptFeedback": { + "blockReason": "SAFETY", + "block_reason_message": f"[Invariant] The {location} did not pass the guardrails: " + + json.dumps(guardrails_execution_result), + "safetyRatings": [ + { + "category": "HARM_CATEGORY_UNSPECIFIED", + "probability": "HIGH", + "blocked": True, + } + ], + }, + } + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + self.make_refusal("request", self.guardrails_execution_result) + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + {}, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return ExtraItem( + f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True + ) + + async def event_generator(self): + response = await self.client.send(self.gemini_request, stream=True) + + if response.status_code != 200: + error_content = await response.aread() + try: + error_json = json.loads(error_content.decode("utf-8")) + error_detail = error_json.get("error", "Unknown error from Gemini API") + except json.JSONDecodeError: + error_detail = {"error": "Failed to parse Gemini error response"} + raise HTTPException(status_code=response.status_code, detail=error_detail) + + async for chunk in response.aiter_bytes(): + yield chunk + + async def on_chunk(self, chunk): + chunk_text = chunk.decode().strip() + if not chunk_text: + return + + # Parse and update merged_response incrementally + process_chunk_text(self.merged_response, chunk_text) + + # runs on the last stream item + if ( + self.merged_response.get("candidates", []) + and self.merged_response.get("candidates")[0].get("finishReason", "") + and self.context.config + and self.context.config.guardrails + ): + # Block on the guardrails check + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + self.make_refusal("response", self.guardrails_execution_result) + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + return ExtraItem( + value=f"data: {error_chunk}\r\n\r\n".encode(), + # for Gemini we have to end the stream prematurely, as the client SDK + # will not stop streaming when it encounters an error + end_of_stream=True, + ) + + async def on_end(self): + """Runs when the stream ends.""" + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + async def stream_response( context: RequestContextData, client: httpx.AsyncClient, @@ -96,76 +260,21 @@ async def stream_response( ) -> Response: """Handles streaming the Gemini response to the client""" - response = await client.send(gemini_request, stream=True) - if response.status_code != 200: - error_content = await response.aread() - try: - error_json = json.loads(error_content.decode("utf-8")) - error_detail = error_json.get("error", "Unknown error from Gemini API") - except json.JSONDecodeError: - error_detail = {"error": "Failed to parse Gemini error response"} - raise HTTPException(status_code=response.status_code, detail=error_detail) + response = InstrumentedStreamingGeminiResponse( + context=context, + client=client, + gemini_request=gemini_request, + ) - async def event_generator() -> Any: - # Store the progressively merged response - merged_response = { - "candidates": [{"content": {"parts": []}, "finishReason": None}] - } - - async for chunk in response.aiter_bytes(): - chunk_text = chunk.decode().strip() - if not chunk_text: - continue - - # Parse and update merged_response incrementally - process_chunk_text(merged_response, chunk_text) - - if ( - merged_response.get("candidates", []) - and merged_response.get("candidates")[0].get("finishReason", "") - and context.config - and context.config.guardrails - ): - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "error": { - "code": 400, - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - "status": "INVARIANT_GUARDRAILS_VIOLATION", - }, - } - ) - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - yield f"data: {error_chunk}\n\n".encode() - return - - # Yield chunk immediately to the client + async def event_generator(): + async for chunk in response.instrumented_event_generator(): yield chunk + print("chunk", chunk) - if context.dataset_name: - # Push to Explorer - don't block on the response - asyncio.create_task( - push_to_explorer( - context, - merged_response, - ) - ) - - return StreamingResponse(event_generator(), media_type="text/event-stream") + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) def process_chunk_text( @@ -281,53 +390,165 @@ async def push_to_explorer( ) +class InstrumentedGeminiResponse(InstrumentedResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + gemini_request: httpx.Request, + ): + super().__init__() + + # request data + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.gemini_request: httpx.Request = gemini_request + + # response data + self.response: Optional[httpx.Response] = None + self.response_json: Optional[dict[str, Any]] = None + + # guardrails execution result (if any) + self.guardrails_execution_result: Optional[dict[str, Any]] = None + + async def on_start(self): + """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).""" + if self.context.config and self.context.config.guardrails: + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, {} + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "code": 400, + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + "status": "INVARIANT_GUARDRAILS_VIOLATION", + }, + "prompt_feedback": { + "blockReason": "SAFETY", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_UNSPECIFIED", + "probability": 0.0, + "blocked": True, + } + ], + }, + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + {}, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return Replacement( + Response( + content=error_chunk, + status_code=400, + media_type="application/json", + headers={ + "Content-Type": "application/json", + }, + ) + ) + + async def request(self): + self.response = await self.client.send(self.gemini_request) + + response_string = self.response.text + response_code = self.response.status_code + + try: + self.response_json = self.response.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=self.response.status_code, + detail="Invalid JSON response received from Gemini API", + ) from e + if self.response.status_code != 200: + raise HTTPException( + status_code=self.response.status_code, + detail=self.response_json.get("error", "Unknown error from Gemini API"), + ) + + return Response( + content=response_string, + status_code=response_code, + media_type="application/json", + headers=dict(self.response.headers), + ) + + async def on_end(self): + response_string = json.dumps(self.response_json) + response_code = self.response.status_code + + if self.context.config and self.context.config.guardrails: + # Block on the guardrails check + guardrails_execution_result = await get_guardrails_check_result( + self.context, self.response_json + ) + if guardrails_execution_result.get("errors", []): + response_string = json.dumps( + { + "error": { + "code": 400, + "message": "[Invariant] The response did not pass the guardrails", + "details": guardrails_execution_result, + "status": "INVARIANT_GUARDRAILS_VIOLATION", + }, + } + ) + response_code = 400 + + if self.context.dataset_name: + # Push to Explorer - don't block on its response + asyncio.create_task( + push_to_explorer( + self.context, + self.response_json, + guardrails_execution_result, + ) + ) + + return Replacement( + Response( + content=response_string, + status_code=response_code, + media_type="application/json", + headers=dict(self.response.headers), + ) + ) + + # Otherwise, also push to Explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, self.response_json, guardrails_execution_result + ) + ) + + async def handle_non_streaming_response( context: RequestContextData, - response: httpx.Response, + client: httpx.AsyncClient, + gemini_request: httpx.Request, ) -> Response: """Handles non-streaming Gemini responses""" - try: - response_json = response.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=response.status_code, - detail="Invalid JSON response received from Gemini API", - ) from e - if response.status_code != 200: - raise HTTPException( - status_code=response.status_code, - detail=response_json.get("error", "Unknown error from Gemini API"), - ) - guardrails_execution_result = {} - response_string = json.dumps(response_json) - response_code = response.status_code - if context.config and context.config.guardrails: - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, response_json - ) - if guardrails_execution_result.get("errors", []): - response_string = json.dumps( - { - "error": { - "code": 400, - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - "status": "INVARIANT_GUARDRAILS_VIOLATION", - }, - } - ) - response_code = 400 - if context.dataset_name: - # Push to Explorer - don't block on its response - asyncio.create_task( - push_to_explorer(context, response_json, guardrails_execution_result) - ) - - return Response( - content=response_string, - status_code=response_code, - media_type="application/json", - headers=dict(response.headers), + response = InstrumentedGeminiResponse( + context=context, + client=client, + gemini_request=gemini_request, ) + + return await response.instrumented_request() diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index dfc7845..c463186 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file( else: response = client.models.generate_content_stream(**request) - for chunk in response: - assert "Dublin" not in str(chunk) + assert_is_streamed_refusal( + response, + [ + "[Invariant] The response did not pass the guardrails", + "Dublin detected in the response", + ], + ) if push_to_explorer: # Wait for the trace to be saved @@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file( **request, ) - for chunk in response: - assert "Madrid" not in str(chunk) + assert_is_streamed_refusal( + response, + [ + "[Invariant] The response did not pass the guardrails", + "get_capital is called with Germany as argument", + ], + ) if push_to_explorer: # Wait for the trace to be saved @@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file( == "get_capital is called with Germany as argument" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, push_to_explorer", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_input_from_guardrail_from_file( + explorer_api_url, gateway_url, do_stream, push_to_explorer +): + """Test input guardrail enforcement with Gemini.""" + if not os.getenv("INVARIANT_API_KEY"): + pytest.fail("No INVARIANT_API_KEY set, failing") + + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + + client = genai.Client( + api_key=os.getenv("GEMINI_API_KEY"), + http_options={ + "headers": { + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + "base_url": f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/gemini", + }, + ) + + request = { + "model": "gemini-2.0-flash", + "contents": "Tell me more about Fight Club.", + "config": { + "maxOutputTokens": 200, + }, + } + + if not do_stream: + with pytest.raises(genai.errors.ClientError) as exc_info: + client.models.generate_content(**request) + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "Users must not mention the magic phrase 'Fight Club'" in str( + exc_info.value + ) + + else: + response = client.models.generate_content_stream(**request) + + assert_is_streamed_refusal( + response, + [ + "[Invariant] The request did not pass the guardrails", + "Users must not mention the magic phrase 'Fight Club'", + ], + ) + + if push_to_explorer: + time.sleep(2) + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 1 + assert trace["messages"][0] == { + "role": "user", + "content": [{"type": "text", "text": "Tell me more about Fight Club."}], + } + + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] + == "Users must not mention the magic phrase 'Fight Club'" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + ) + + +def is_refusal(chunk): + return ( + len(chunk.candidates) == 1 + and chunk.candidates[0].content.parts[0].text.startswith("[Invariant]") + and chunk.prompt_feedback is not None + and "BlockedReason.SAFETY" in str(chunk.prompt_feedback) + ) + + +def assert_is_streamed_refusal(response, expected_message_components: list[str]): + """ + Validates that the streamed response contains a refusal at the end (or as only message). + """ + num_chunks = 0 + for c in response: + num_chunks += 1 + + assert num_chunks >= 1, "Expected at least one chunk" + # last chunk must be a refusal + assert is_refusal(c) + + for emc in expected_message_components: + assert ( + emc in c.model_dump_json() + ), f"Expected message component {emc} not found in refusal message: {c.model_dump_json()}" diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index cf9e274..acc2f67 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file( trace = trace_response.json() # in case of input guardrailing, the pushed trace will not contain a response - assert len(trace["messages"]) == 1, "Trace should only contain the user message" + assert len(trace["messages"]) == 1 assert trace["messages"][0] == { "role": "user", "content": "Tell me more about Fight Club.",