extend tests for input guardrailing

This commit is contained in:
Luca Beurer-Kellner
2025-03-28 20:10:54 +01:00
committed by Hemang
parent 2a66582c7c
commit 7f820bd79f
3 changed files with 113 additions and 2 deletions
+1 -1
View File
@@ -500,7 +500,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
Response(
content=json.dumps(
{
"error": "[Invariant] The response did not pass the guardrails",
"error": "[Invariant] The request did not pass the guardrails",
"details": self.guardrails_execution_result,
}
),
@@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file(
== "get_capital is called with Germany as argument"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, push_to_explorer",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_input_from_guardrail_from_file(
explorer_api_url, gateway_url, do_stream, push_to_explorer
):
"""Test the message content guardrail."""
if not os.getenv("INVARIANT_API_KEY"):
pytest.fail("No INVARIANT_API_KEY set, failing")
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = OpenAI(
http_client=Client(
headers={
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
},
),
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
if push_to_explorer
else f"{gateway_url}/api/v1/gateway/openai",
)
request = {
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
}
if not do_stream:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value
)
else:
with pytest.raises(APIError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert (
"[Invariant] The request did not pass the guardrails"
in exc_info.value.message
)
assert "Users must not mention the magic phrase 'Fight Club'" in str(
exc_info.value.body
)
if push_to_explorer:
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
# in case of input guardrailing, the pushed trace will not contain a response
assert len(trace["messages"]) == 1
assert trace["messages"][0] == {
"role": "user",
"content": "Tell me more about Fight Club.",
}
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"]
== "Users must not mention the magic phrase 'Fight Club'"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@@ -13,4 +13,10 @@ raise "Dublin detected in the response" if:
raise "get_capital is called with Germany as argument" if:
(call: ToolCall)
call is tool:get_capital
call.function.arguments["country_name"] == "Germany"
call.function.arguments["country_name"] == "Germany"
# For input guardrailing specifically
raise "Users must not mention the magic phrase 'Fight Club'" if:
(msg: Message)
msg.role == "user"
"Fight Club" in msg.content