Use tenacity.RetryError instead of genai.errors.ClientError for gemini guardrailing errors.

This commit is contained in:
Hemang
2025-06-19 12:54:29 +02:00
committed by Hemang Sarkar
parent 70091b7f53
commit df33199343
2 changed files with 39 additions and 25 deletions
@@ -2,17 +2,17 @@
import os
import sys
import uuid
import time
import uuid
# Add integration folder (parent) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_gemini_client, create_dataset, add_guardrail_to_dataset
import pytest
import requests
import tenacity
from google import genai
from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client
# Pytest plugins
pytest_plugins = ("pytest_asyncio",)
@@ -46,14 +46,17 @@ async def test_message_content_guardrail_from_file(
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
with pytest.raises(tenacity.RetryError) as exc_info:
response = client.models.generate_content(
**request,
)
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info
)
assert "Dublin detected in the response" in str(exc_info)
original_error = exc_info.value.last_attempt.exception()
assert isinstance(original_error, genai.errors.ClientError)
assert original_error.code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
original_error
)
assert "Dublin detected in the response" in str(original_error)
else:
response = client.models.generate_content_stream(**request)
@@ -148,16 +151,17 @@ async def test_tool_call_guardrail_from_file(
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
with pytest.raises(tenacity.RetryError) as exc_info:
client.models.generate_content(
**request,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info
)
assert "get_capital is called with Germany as argument" in str(exc_info)
original_error = exc_info.value.last_attempt.exception()
assert isinstance(original_error, genai.errors.ClientError)
assert original_error.code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
original_error
)
assert "get_capital is called with Germany as argument" in str(original_error)
else:
response = client.models.generate_content_stream(
@@ -246,14 +250,17 @@ async def test_input_from_guardrail_from_file(
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
with pytest.raises(tenacity.RetryError) as exc_info:
client.models.generate_content(**request)
original_error = exc_info.value.last_attempt.exception()
assert isinstance(original_error, genai.errors.ClientError)
assert original_error.code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
original_error
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
original_error
)
else:
@@ -372,15 +379,18 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
},
}
if not do_stream:
with pytest.raises(genai.errors.ClientError) as exc_info:
with pytest.raises(tenacity.RetryError) as exc_info:
client.models.generate_content(**shrek_request)
original_error = exc_info.value.last_attempt.exception()
assert isinstance(original_error, genai.errors.ClientError)
assert original_error.code == 400
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
original_error
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
assert "ogre detected in response" in str(original_error)
assert "Fiona detected in response" not in str(original_error)
else:
response = client.models.generate_content_stream(**shrek_request)
@@ -492,12 +502,15 @@ async def test_preguardrailing_with_guardrails_from_explorer(
],
)
else:
with pytest.raises(genai.errors.ClientError) as exc_info:
with pytest.raises(tenacity.RetryError) as exc_info:
chat_response = client.models.generate_content(**request)
original_error = exc_info.value.last_attempt.exception()
assert isinstance(original_error, genai.errors.ClientError)
assert original_error.code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
original_error
)
assert "pun detected in user message" in str(exc_info.value)
assert "pun detected in user message" in str(original_error)
else:
if do_stream:
response = client.models.generate_content_stream(**request)
+1
View File
@@ -9,4 +9,5 @@ pytest
pytest-asyncio
pytest-timeout
tavily-python
tenacity
uv