mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-22 14:59:41 +02:00
fix gemini streamed refusal
This commit is contained in:
committed by
Hemang
parent
c2177faaa8
commit
cd6c15105f
@@ -6,13 +6,8 @@ import time
|
||||
from typing import Any, Dict, List
|
||||
from functools import wraps
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
<<<<<<< HEAD:gateway/integrations/guardails.py
|
||||
=======
|
||||
from zmq import IO_THREADS
|
||||
from common.request_context_data import RequestContextData
|
||||
>>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py
|
||||
|
||||
DEFAULT_API_URL = "https://explorer.invariantlabs.ai"
|
||||
|
||||
@@ -250,6 +245,8 @@ class InstrumentedStreamingResponse:
|
||||
yield extra_item.value
|
||||
# if end_of_stream is True, stop the stream
|
||||
if extra_item.end_of_stream:
|
||||
# cancel next task
|
||||
next_item_task.cancel()
|
||||
return
|
||||
|
||||
# yield item
|
||||
|
||||
+335
-114
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from common.config_manager import GatewayConfig, GatewayConfigManager
|
||||
@@ -15,6 +15,14 @@ from common.constants import (
|
||||
from common.authorization import extract_authorization_from_headers
|
||||
from common.request_context_data import RequestContextData
|
||||
from converters.gemini_to_invariant import convert_request, convert_response
|
||||
from integrations.guardrails import (
|
||||
ExtraItem,
|
||||
InstrumentedResponse,
|
||||
InstrumentedStreamingResponse,
|
||||
Replacement,
|
||||
preload_guardrails,
|
||||
check_guardrails,
|
||||
)
|
||||
from integrations.explorer import create_annotations_from_guardrails_errors, push_trace
|
||||
from integrations.guardrails import check_guardrails, preload_guardrails
|
||||
|
||||
@@ -82,13 +90,169 @@ async def gemini_generate_content_gateway(
|
||||
client,
|
||||
gemini_request,
|
||||
)
|
||||
response = await client.send(gemini_request)
|
||||
return await handle_non_streaming_response(
|
||||
context,
|
||||
response,
|
||||
client,
|
||||
gemini_request,
|
||||
)
|
||||
|
||||
|
||||
class InstrumentedStreamingGeminiResponse(InstrumentedStreamingResponse):
    """Streamed Gemini response instrumented with guardrail checks.

    The hooks defined here are invoked by ``InstrumentedStreamingResponse``:

    * ``on_start`` checks the request against the guardrails,
    * ``event_generator`` produces the raw upstream byte chunks,
    * ``on_chunk`` merges every chunk into ``merged_response`` and checks the
      merged response once a ``finishReason`` has been seen,
    * ``on_end`` pushes the trace to the Explorer.
    """

    def __init__(
        self,
        context: RequestContextData,
        client: httpx.AsyncClient,
        gemini_request: httpx.Request,
    ):
        super().__init__()

        # request data
        self.context: RequestContextData = context
        self.client: httpx.AsyncClient = client
        self.gemini_request: httpx.Request = gemini_request

        # Store the progressively merged response
        self.merged_response = {
            "candidates": [{"content": {"parts": []}, "finishReason": None}]
        }

        # guardrailing execution result (if any)
        self.guardrails_execution_result: Optional[dict[str, Any]] = None

    def make_refusal(
        self,
        location: Literal["request", "response"],
        guardrails_execution_result: dict[str, Any],
    ) -> dict:
        """Build the refusal payload sent to the client instead of model output.

        Args:
            location: Which side violated the guardrails; interpolated into
                the human-readable messages.
            guardrails_execution_result: Result of the guardrails check; it is
                embedded as ``error.details`` and serialized into the prompt
                feedback message.

        Returns:
            A dict with a single text candidate, an ``error`` object and a
            ``promptFeedback`` object whose block reason is ``SAFETY``.
        """
        return {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "text": f"[Invariant] The {location} did not pass the guardrails",
                            }
                        ],
                    }
                }
            ],
            "error": {
                "code": 400,
                "message": f"[Invariant] The {location} did not pass the guardrails",
                "details": guardrails_execution_result,
                "status": "INVARIANT_GUARDRAILS_VIOLATION",
            },
            "promptFeedback": {
                "blockReason": "SAFETY",
                "block_reason_message": f"[Invariant] The {location} did not pass the guardrails: "
                + json.dumps(guardrails_execution_result),
                "safetyRatings": [
                    {
                        "category": "HARM_CATEGORY_UNSPECIFIED",
                        "probability": "HIGH",
                        "blocked": True,
                    }
                ],
            },
        }

    async def on_start(self):
        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).

        Returns:
            An ``ExtraItem`` carrying a refusal chunk with ``end_of_stream=True``
            when the check reports errors; otherwise ``None``.
        """
        if self.context.config and self.context.config.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, {}
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    self.make_refusal("request", self.guardrails_execution_result)
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we end the stream prematurely (end_of_stream=True)
                # and yield an error chunk instead of actually beginning the stream
                return ExtraItem(
                    f"data: {error_chunk}\r\n\r\n".encode(), end_of_stream=True
                )

    async def event_generator(self):
        """Send the request upstream and yield the raw response byte chunks.

        Raises:
            HTTPException: If Gemini answers with a non-200 status; the detail
                is the upstream ``error`` object when it can be parsed.
        """
        response = await self.client.send(self.gemini_request, stream=True)

        # A response obtained with stream=True must be closed explicitly,
        # otherwise the underlying connection is never released. The finally
        # block also covers the error path and early termination of this
        # generator (e.g. when a refusal ends the stream prematurely).
        try:
            if response.status_code != 200:
                error_content = await response.aread()
                try:
                    error_json = json.loads(error_content.decode("utf-8"))
                    error_detail = error_json.get(
                        "error", "Unknown error from Gemini API"
                    )
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # an undecodable body is treated like an unparsable one
                    error_detail = {"error": "Failed to parse Gemini error response"}
                raise HTTPException(
                    status_code=response.status_code, detail=error_detail
                )

            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            await response.aclose()

    async def on_chunk(self, chunk):
        """Merge a streamed chunk and run output guardrails once the stream finished.

        Returns:
            An ``ExtraItem`` carrying a refusal chunk with ``end_of_stream=True``
            when the guardrails report errors; otherwise ``None``.
        """
        # NOTE(review): chunks come from aiter_bytes(), so a multi-byte UTF-8
        # character split across two chunks would make decode() raise - confirm
        # whether the upstream chunking can do that.
        chunk_text = chunk.decode().strip()
        if not chunk_text:
            return

        # Parse and update merged_response incrementally
        process_chunk_text(self.merged_response, chunk_text)

        # runs on the last stream item
        if (
            self.merged_response.get("candidates", [])
            and self.merged_response.get("candidates")[0].get("finishReason", "")
            and self.context.config
            and self.context.config.guardrails
        ):
            # Block on the guardrails check
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.merged_response
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    self.make_refusal("response", self.guardrails_execution_result)
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.merged_response,
                            self.guardrails_execution_result,
                        )
                    )

                return ExtraItem(
                    value=f"data: {error_chunk}\r\n\r\n".encode(),
                    # for Gemini we have to end the stream prematurely, as the client SDK
                    # will not stop streaming when it encounters an error
                    end_of_stream=True,
                )

    async def on_end(self):
        """Runs when the stream ends."""

        # Push annotated trace to the explorer - don't block on its response
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context,
                    self.merged_response,
                    self.guardrails_execution_result,
                )
            )
|
||||
|
||||
|
||||
async def stream_response(
|
||||
context: RequestContextData,
|
||||
client: httpx.AsyncClient,
|
||||
@@ -96,76 +260,21 @@ async def stream_response(
|
||||
) -> Response:
|
||||
"""Handles streaming the Gemini response to the client"""
|
||||
|
||||
response = await client.send(gemini_request, stream=True)
|
||||
if response.status_code != 200:
|
||||
error_content = await response.aread()
|
||||
try:
|
||||
error_json = json.loads(error_content.decode("utf-8"))
|
||||
error_detail = error_json.get("error", "Unknown error from Gemini API")
|
||||
except json.JSONDecodeError:
|
||||
error_detail = {"error": "Failed to parse Gemini error response"}
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
response = InstrumentedStreamingGeminiResponse(
|
||||
context=context,
|
||||
client=client,
|
||||
gemini_request=gemini_request,
|
||||
)
|
||||
|
||||
async def event_generator() -> Any:
|
||||
# Store the progressively merged response
|
||||
merged_response = {
|
||||
"candidates": [{"content": {"parts": []}, "finishReason": None}]
|
||||
}
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
chunk_text = chunk.decode().strip()
|
||||
if not chunk_text:
|
||||
continue
|
||||
|
||||
# Parse and update merged_response incrementally
|
||||
process_chunk_text(merged_response, chunk_text)
|
||||
|
||||
if (
|
||||
merged_response.get("candidates", [])
|
||||
and merged_response.get("candidates")[0].get("finishReason", "")
|
||||
and context.config
|
||||
and context.config.guardrails
|
||||
):
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, merged_response
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
error_chunk = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"code": 400,
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
"status": "INVARIANT_GUARDRAILS_VIOLATION",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Push annotated trace to the explorer - don't block on its response
|
||||
if context.dataset_name:
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
guardrails_execution_result,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_chunk}\n\n".encode()
|
||||
return
|
||||
|
||||
# Yield chunk immediately to the client
|
||||
async def event_generator():
|
||||
async for chunk in response.instrumented_event_generator():
|
||||
yield chunk
|
||||
print("chunk", chunk)
|
||||
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on the response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(
|
||||
context,
|
||||
merged_response,
|
||||
)
|
||||
)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def process_chunk_text(
|
||||
@@ -281,53 +390,165 @@ async def push_to_explorer(
|
||||
)
|
||||
|
||||
|
||||
class InstrumentedGeminiResponse(InstrumentedResponse):
    """Non-streaming Gemini response instrumented with guardrail checks.

    The hooks defined here are invoked by ``InstrumentedResponse``:
    ``on_start`` checks the request, ``request`` performs the upstream call and
    ``on_end`` checks the response and pushes the trace to the Explorer.
    """

    def __init__(
        self,
        context: RequestContextData,
        client: httpx.AsyncClient,
        gemini_request: httpx.Request,
    ):
        super().__init__()

        # request data
        self.context: RequestContextData = context
        self.client: httpx.AsyncClient = client
        self.gemini_request: httpx.Request = gemini_request

        # response data
        self.response: Optional[httpx.Response] = None
        self.response_json: Optional[dict[str, Any]] = None

        # guardrails execution result (if any)
        self.guardrails_execution_result: Optional[dict[str, Any]] = None

    async def on_start(self):
        """Check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing).

        Returns:
            A ``Replacement`` wrapping a 400 JSON error response when the check
            reports errors; otherwise ``None``.
        """
        if self.context.config and self.context.config.guardrails:
            self.guardrails_execution_result = await get_guardrails_check_result(
                self.context, {}
            )
            if self.guardrails_execution_result.get("errors", []):
                error_chunk = json.dumps(
                    {
                        "error": {
                            "code": 400,
                            "message": "[Invariant] The request did not pass the guardrails",
                            "details": self.guardrails_execution_result,
                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
                        },
                        "prompt_feedback": {
                            "blockReason": "SAFETY",
                            "safetyRatings": [
                                {
                                    "category": "HARM_CATEGORY_UNSPECIFIED",
                                    "probability": 0.0,
                                    "blocked": True,
                                }
                            ],
                        },
                    }
                )

                # Push annotated trace to the explorer - don't block on its response
                if self.context.dataset_name:
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            {},
                            self.guardrails_execution_result,
                        )
                    )

                # if we find something, we hand back a 400 error response
                # in place of the upstream response
                return Replacement(
                    Response(
                        content=error_chunk,
                        status_code=400,
                        media_type="application/json",
                        headers={
                            "Content-Type": "application/json",
                        },
                    )
                )

    async def request(self):
        """Send the request to Gemini and wrap the upstream reply.

        Returns:
            A ``Response`` with the upstream body text, status code and headers.

        Raises:
            HTTPException: If the body is not valid JSON, or if the upstream
                status code is not 200 (detail is the upstream ``error``).
        """
        self.response = await self.client.send(self.gemini_request)

        response_string = self.response.text
        response_code = self.response.status_code

        try:
            self.response_json = self.response.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=self.response.status_code,
                detail="Invalid JSON response received from Gemini API",
            ) from e
        if self.response.status_code != 200:
            raise HTTPException(
                status_code=self.response.status_code,
                detail=self.response_json.get("error", "Unknown error from Gemini API"),
            )

        return Response(
            content=response_string,
            status_code=response_code,
            media_type="application/json",
            headers=dict(self.response.headers),
        )

    async def on_end(self):
        """Check the response against the guardrails and push the trace.

        Returns:
            A ``Replacement`` wrapping a 400 JSON error response when the check
            reports errors; otherwise ``None``.
        """
        # Bound up front so the Explorer push at the bottom also works when no
        # guardrails are configured (previously an UnboundLocalError).
        guardrails_execution_result = {}

        if self.context.config and self.context.config.guardrails:
            # Block on the guardrails check
            guardrails_execution_result = await get_guardrails_check_result(
                self.context, self.response_json
            )
            if guardrails_execution_result.get("errors", []):
                response_string = json.dumps(
                    {
                        "error": {
                            "code": 400,
                            "message": "[Invariant] The response did not pass the guardrails",
                            "details": guardrails_execution_result,
                            "status": "INVARIANT_GUARDRAILS_VIOLATION",
                        },
                    }
                )

                if self.context.dataset_name:
                    # Push to Explorer - don't block on its response
                    asyncio.create_task(
                        push_to_explorer(
                            self.context,
                            self.response_json,
                            guardrails_execution_result,
                        )
                    )

                # NOTE(review): the upstream headers are reused for a body of
                # different length - confirm no stale Content-Length is sent.
                return Replacement(
                    Response(
                        content=response_string,
                        status_code=400,
                        media_type="application/json",
                        headers=dict(self.response.headers),
                    )
                )

        # Otherwise, also push to Explorer - don't block on its response
        if self.context.dataset_name:
            asyncio.create_task(
                push_to_explorer(
                    self.context, self.response_json, guardrails_execution_result
                )
            )
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
context: RequestContextData,
|
||||
response: httpx.Response,
|
||||
client: httpx.AsyncClient,
|
||||
gemini_request: httpx.Request,
|
||||
) -> Response:
|
||||
"""Handles non-streaming Gemini responses"""
|
||||
try:
|
||||
response_json = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail="Invalid JSON response received from Gemini API",
|
||||
) from e
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response_json.get("error", "Unknown error from Gemini API"),
|
||||
)
|
||||
guardrails_execution_result = {}
|
||||
response_string = json.dumps(response_json)
|
||||
response_code = response.status_code
|
||||
|
||||
if context.config and context.config.guardrails:
|
||||
# Block on the guardrails check
|
||||
guardrails_execution_result = await get_guardrails_check_result(
|
||||
context, response_json
|
||||
)
|
||||
if guardrails_execution_result.get("errors", []):
|
||||
response_string = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"code": 400,
|
||||
"message": "[Invariant] The response did not pass the guardrails",
|
||||
"details": guardrails_execution_result,
|
||||
"status": "INVARIANT_GUARDRAILS_VIOLATION",
|
||||
},
|
||||
}
|
||||
)
|
||||
response_code = 400
|
||||
if context.dataset_name:
|
||||
# Push to Explorer - don't block on its response
|
||||
asyncio.create_task(
|
||||
push_to_explorer(context, response_json, guardrails_execution_result)
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=response_string,
|
||||
status_code=response_code,
|
||||
media_type="application/json",
|
||||
headers=dict(response.headers),
|
||||
response = InstrumentedGeminiResponse(
|
||||
context=context,
|
||||
client=client,
|
||||
gemini_request=gemini_request,
|
||||
)
|
||||
|
||||
return await response.instrumented_request()
|
||||
|
||||
@@ -63,8 +63,13 @@ async def test_message_content_guardrail_from_file(
|
||||
|
||||
else:
|
||||
response = client.models.generate_content_stream(**request)
|
||||
for chunk in response:
|
||||
assert "Dublin" not in str(chunk)
|
||||
assert_is_streamed_refusal(
|
||||
response,
|
||||
[
|
||||
"[Invariant] The response did not pass the guardrails",
|
||||
"Dublin detected in the response",
|
||||
],
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
# Wait for the trace to be saved
|
||||
@@ -172,8 +177,13 @@ async def test_tool_call_guardrail_from_file(
|
||||
**request,
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
assert "Madrid" not in str(chunk)
|
||||
assert_is_streamed_refusal(
|
||||
response,
|
||||
[
|
||||
"[Invariant] The response did not pass the guardrails",
|
||||
"get_capital is called with Germany as argument",
|
||||
],
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
# Wait for the trace to be saved
|
||||
@@ -219,3 +229,122 @@ async def test_tool_call_guardrail_from_file(
|
||||
== "get_capital is called with Germany as argument"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
    explorer_api_url, gateway_url, do_stream, push_to_explorer
):
    """Input guardrails must refuse a blocked Gemini request (streamed or not).

    When pushing to the Explorer is enabled, the stored trace must hold only
    the user message plus a single guardrails-error annotation.
    """
    invariant_api_key = os.getenv("INVARIANT_API_KEY")
    if not invariant_api_key:
        pytest.fail("No INVARIANT_API_KEY set, failing")

    dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"

    # the dataset name is part of the gateway route only when traces are pushed
    if push_to_explorer:
        base_url = f"{gateway_url}/api/v1/gateway/{dataset_name}/gemini"
    else:
        base_url = f"{gateway_url}/api/v1/gateway/gemini"

    client = genai.Client(
        api_key=os.getenv("GEMINI_API_KEY"),
        http_options={
            "headers": {"Invariant-Authorization": f"Bearer {invariant_api_key}"},
            "base_url": base_url,
        },
    )

    request = {
        "model": "gemini-2.0-flash",
        "contents": "Tell me more about Fight Club.",
        "config": {
            "maxOutputTokens": 200,
        },
    }

    if do_stream:
        # streamed: the refusal arrives as the final (or only) chunk
        response = client.models.generate_content_stream(**request)

        assert_is_streamed_refusal(
            response,
            [
                "[Invariant] The request did not pass the guardrails",
                "Users must not mention the magic phrase 'Fight Club'",
            ],
        )
    else:
        # non-streamed: the SDK surfaces the refusal as a client error
        with pytest.raises(genai.errors.ClientError) as exc_info:
            client.models.generate_content(**request)

        error_text = str(exc_info.value)
        assert "[Invariant] The request did not pass the guardrails" in error_text
        assert "Users must not mention the magic phrase 'Fight Club'" in error_text

    if not push_to_explorer:
        return

    # Wait for the trace to be saved
    time.sleep(2)
    traces = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    ).json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    trace = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    ).json()

    messages = trace["messages"]
    assert len(messages) == 1
    assert messages[0] == {
        "role": "user",
        "content": [{"type": "text", "text": "Tell me more about Fight Club."}],
    }

    annotations = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
        timeout=5,
    ).json()

    assert len(annotations) == 1
    annotation = annotations[0]
    assert (
        annotation["content"] == "Users must not mention the magic phrase 'Fight Club'"
    )
    assert annotation["extra_metadata"]["source"] == "guardrails-error"
|
||||
|
||||
|
||||
def is_refusal(chunk):
    """Return True if *chunk* is an Invariant guardrails refusal.

    A refusal has exactly one candidate whose first part is text starting with
    ``[Invariant]``, and a prompt feedback whose string form mentions
    ``BlockedReason.SAFETY``.

    Chunks without candidates, without parts, or whose first part carries no
    text (for instance a function call) are reported as "not a refusal"
    instead of raising, so callers get a clean assertion failure.
    """
    candidates = getattr(chunk, "candidates", None) or []
    if len(candidates) != 1:
        return False

    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []
    if not parts:
        return False

    text = getattr(parts[0], "text", None)
    if not isinstance(text, str) or not text.startswith("[Invariant]"):
        return False

    prompt_feedback = getattr(chunk, "prompt_feedback", None)
    return prompt_feedback is not None and "BlockedReason.SAFETY" in str(
        prompt_feedback
    )
|
||||
|
||||
|
||||
def assert_is_streamed_refusal(response, expected_message_components: list[str]):
    """
    Drain the streamed response and assert that its final chunk (possibly the
    only one) is a refusal containing every expected message component.
    """
    chunk_count = 0
    final_chunk = None
    # consume the whole stream, remembering only the last chunk
    for chunk_count, final_chunk in enumerate(response, start=1):
        pass

    assert chunk_count >= 1, "Expected at least one chunk"
    # last chunk must be a refusal
    assert is_refusal(final_chunk)

    for component in expected_message_components:
        assert (
            component in final_chunk.model_dump_json()
        ), f"Expected message component {component} not found in refusal message: {final_chunk.model_dump_json()}"
|
||||
|
||||
@@ -330,7 +330,7 @@ async def test_input_from_guardrail_from_file(
|
||||
trace = trace_response.json()
|
||||
|
||||
# in case of input guardrailing, the pushed trace will not contain a response
|
||||
assert len(trace["messages"]) == 1, "Trace should only contain the user message"
|
||||
assert len(trace["messages"]) == 1
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": "Tell me more about Fight Club.",
|
||||
|
||||
Reference in New Issue
Block a user