Add preguardrailing tests for guardrails pulled from explorer.

This commit is contained in:
Hemang
2025-04-02 11:25:39 +02:00
committed by Hemang Sarkar
parent 55f0f741c0
commit f3a56e1e43
5 changed files with 379 additions and 22 deletions
+1
View File
@@ -536,6 +536,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
)
async def on_end(self):
"""Runs when the request ends."""
response_string = json.dumps(self.response_json)
response_code = self.response.status_code
+16 -17
View File
@@ -432,27 +432,26 @@ async def push_to_explorer(
# or if the guardrails check returned errors.
guardrails_execution_result = guardrails_execution_result or {}
guardrails_errors = guardrails_execution_result.get("errors", [])
if guardrails_errors or not (
annotations = create_annotations_from_guardrails_errors(
guardrails_errors, action="block"
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
if annotations or not (
merged_response.get("choices")
and merged_response["choices"][0].get("finish_reason")
not in FINISH_REASON_TO_PUSH_TRACE
):
annotations = create_annotations_from_guardrails_errors(
guardrails_errors, action="block"
)
# Execute the logging guardrails before pushing to Explorer
logging_guardrails_execution_result = await get_guardrails_check_result(
context,
action=GuardrailAction.LOG,
response_json=merged_response,
)
logging_annotations = create_annotations_from_guardrails_errors(
logging_guardrails_execution_result.get("errors", []), action="log"
)
# Update the annotations with the logging guardrails
annotations.extend(logging_annotations)
# Combine the messages from the request body and the choices from the OpenAI response
messages = list(context.request_json.get("messages", []))
messages += [choice["message"] for choice in merged_response.get("choices", [])]
@@ -316,7 +316,10 @@ async def test_input_from_guardrail_from_file(
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
)
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize("do_stream", [True, False])
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
"""Test that the guardrails from the explorer work."""
@@ -461,3 +464,127 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
)
@pytest.mark.skipif(
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
client = get_anthropic_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "claude-3-5-sonnet-20241022",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if is_block_action:
if do_stream:
with pytest.raises(APIStatusError) as exc_info:
chat_response = client.messages.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
else:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.messages.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
_ = client.messages.create(
**request,
stream=True,
)
else:
_ = client.messages.create(
**request,
stream=False,
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2 if not is_block_action else 1
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)
@@ -444,6 +444,118 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
)
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
client = get_gemini_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "gemini-2.0-flash",
"contents": user_prompt,
"config": {
"maxOutputTokens": 100,
},
}
if is_block_action:
if do_stream:
chat_response = client.models.generate_content_stream(**request)
assert_is_streamed_refusal(
chat_response,
[
"[Invariant] The request did not pass the guardrails",
"pun detected in user message",
],
)
else:
with pytest.raises(genai.errors.ClientError) as exc_info:
chat_response = client.models.generate_content(**request)
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
response = client.models.generate_content_stream(**request)
for _ in response:
pass
else:
_ = client.models.generate_content(**request)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 2 if not is_block_action else 1
assert trace["messages"][0] == {
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt,
}
],
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)
def is_refusal(chunk):
return (
len(chunk.candidates) == 1
@@ -412,15 +412,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
**shrek_request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The response did not pass the guardrails" in str(
exc_info.value
)
# Only the block guardrail should be triggered here
assert "ogre detected in response" in str(exc_info.value)
assert "Fiona detected in response" not in str(exc_info.value)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
@@ -467,3 +464,124 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
)
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
"do_stream, is_block_action",
[(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
explorer_api_url, gateway_url, do_stream, is_block_action
):
"""Test that the guardrails from the explorer work."""
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
client = get_open_ai_client(
gateway_url, push_to_explorer=True, dataset_name=dataset_name
)
dataset_creation_response = await create_dataset(
explorer_api_url,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
dataset_name=dataset_name,
)
dataset_id = dataset_creation_response["id"]
_ = await add_guardrail_to_dataset(
explorer_api_url,
dataset_id=dataset_id,
policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
action="block" if is_block_action else "log",
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
)
user_prompt = "Tell me a one sentence pun."
request = {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": user_prompt,
}
],
"max_tokens": 100,
}
if is_block_action:
if do_stream:
with pytest.raises(APIError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=True,
)
for _ in chat_response:
pass
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
else:
with pytest.raises(BadRequestError) as exc_info:
chat_response = client.chat.completions.create(
**request,
stream=False,
)
assert exc_info.value.status_code == 400
assert "[Invariant] The request did not pass the guardrails" in str(
exc_info.value
)
assert "pun detected in user message" in str(exc_info.value)
else:
if do_stream:
_ = client.chat.completions.create(
**request,
stream=True,
)
else:
_ = client.chat.completions.create(
**request,
stream=False,
)
# Wait for the trace to be saved
# This is needed because the trace is saved asynchronously
time.sleep(2)
# Fetch the trace ids for the dataset
traces_response = requests.get(
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
timeout=5,
)
traces = traces_response.json()
assert len(traces) == 1
trace_id = traces[0]["id"]
# Fetch the trace
trace_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}",
timeout=5,
)
trace = trace_response.json()
assert len(trace["messages"]) == 1 if is_block_action else 2
assert trace["messages"][0] == {
"role": "user",
"content": user_prompt,
}
if not is_block_action:
assert trace["messages"][1].get("role") == "assistant"
# Fetch annotations
annotations_response = requests.get(
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
timeout=5,
)
annotations = annotations_response.json()
assert len(annotations) == 1
assert (
annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block"
if is_block_action
else "log"
)