mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-06-08 06:03:53 +02:00
Fix broken Gemini tests.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Test the guardrails from file with the Gemini route."""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -10,7 +12,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import tenacity
|
||||
from google import genai
|
||||
from utils import add_guardrail_to_dataset, create_dataset, get_gemini_client
|
||||
|
||||
@@ -46,17 +47,16 @@ async def test_message_content_guardrail_from_file(
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(tenacity.RetryError) as exc_info:
|
||||
with pytest.raises(genai.errors.ClientError) as e:
|
||||
response = client.models.generate_content(
|
||||
**request,
|
||||
)
|
||||
original_error = exc_info.value.last_attempt.exception()
|
||||
assert isinstance(original_error, genai.errors.ClientError)
|
||||
assert original_error.code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
original_error
|
||||
assert e._excinfo[1].code == 400
|
||||
assert (
|
||||
"[Invariant] The response did not pass the guardrails"
|
||||
in e._excinfo[1].message
|
||||
)
|
||||
assert "Dublin detected in the response" in str(original_error)
|
||||
assert "Dublin detected in the response" in str(e._excinfo[1])
|
||||
|
||||
else:
|
||||
response = client.models.generate_content_stream(**request)
|
||||
@@ -151,17 +151,16 @@ async def test_tool_call_guardrail_from_file(
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(tenacity.RetryError) as exc_info:
|
||||
with pytest.raises(genai.errors.ClientError) as e:
|
||||
client.models.generate_content(
|
||||
**request,
|
||||
)
|
||||
original_error = exc_info.value.last_attempt.exception()
|
||||
assert isinstance(original_error, genai.errors.ClientError)
|
||||
assert original_error.code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
original_error
|
||||
assert e._excinfo[1].code == 400
|
||||
assert (
|
||||
"[Invariant] The response did not pass the guardrails"
|
||||
in e._excinfo[1].message
|
||||
)
|
||||
assert "get_capital is called with Germany as argument" in str(original_error)
|
||||
assert "get_capital is called with Germany as argument" in str(e._excinfo[1])
|
||||
|
||||
else:
|
||||
response = client.models.generate_content_stream(
|
||||
@@ -250,17 +249,15 @@ async def test_input_from_guardrail_from_file(
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(tenacity.RetryError) as exc_info:
|
||||
with pytest.raises(genai.errors.ClientError) as e:
|
||||
client.models.generate_content(**request)
|
||||
|
||||
original_error = exc_info.value.last_attempt.exception()
|
||||
assert isinstance(original_error, genai.errors.ClientError)
|
||||
assert original_error.code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
original_error
|
||||
assert e._excinfo[1].code == 400
|
||||
assert (
|
||||
"[Invariant] The request did not pass the guardrails"
|
||||
in e._excinfo[1].message
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
original_error
|
||||
e._excinfo[1]
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -379,18 +376,16 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
||||
},
|
||||
}
|
||||
if not do_stream:
|
||||
with pytest.raises(tenacity.RetryError) as exc_info:
|
||||
with pytest.raises(genai.errors.ClientError) as e:
|
||||
client.models.generate_content(**shrek_request)
|
||||
|
||||
original_error = exc_info.value.last_attempt.exception()
|
||||
assert isinstance(original_error, genai.errors.ClientError)
|
||||
assert original_error.code == 400
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
original_error
|
||||
assert e._excinfo[1].code == 400
|
||||
assert (
|
||||
"[Invariant] The response did not pass the guardrails"
|
||||
in e._excinfo[1].message
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(original_error)
|
||||
assert "Fiona detected in response" not in str(original_error)
|
||||
assert "ogre detected in response" in str(e._excinfo[1])
|
||||
assert "Fiona detected in response" not in str(e._excinfo[1])
|
||||
else:
|
||||
response = client.models.generate_content_stream(**shrek_request)
|
||||
|
||||
@@ -502,15 +497,14 @@ async def test_preguardrailing_with_guardrails_from_explorer(
|
||||
],
|
||||
)
|
||||
else:
|
||||
with pytest.raises(tenacity.RetryError) as exc_info:
|
||||
with pytest.raises(genai.errors.ClientError) as e:
|
||||
chat_response = client.models.generate_content(**request)
|
||||
original_error = exc_info.value.last_attempt.exception()
|
||||
assert isinstance(original_error, genai.errors.ClientError)
|
||||
assert original_error.code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
original_error
|
||||
assert e._excinfo[1].code == 400
|
||||
assert (
|
||||
"[Invariant] The request did not pass the guardrails"
|
||||
in e._excinfo[1].message
|
||||
)
|
||||
assert "pun detected in user message" in str(original_error)
|
||||
assert "pun detected in user message" in str(e._excinfo[1])
|
||||
else:
|
||||
if do_stream:
|
||||
response = client.models.generate_content_stream(**request)
|
||||
|
||||
@@ -9,5 +9,4 @@ pytest
|
||||
pytest-asyncio
|
||||
pytest-timeout
|
||||
tavily-python
|
||||
tenacity
|
||||
uv
|
||||
Reference in New Issue
Block a user