mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-01 16:55:31 +02:00
Add preguardrailing tests for guardrails pulled from explorer.
This commit is contained in:
@@ -536,6 +536,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse):
|
||||
)
|
||||
|
||||
async def on_end(self):
|
||||
"""Runs when the request ends."""
|
||||
response_string = json.dumps(self.response_json)
|
||||
response_code = self.response.status_code
|
||||
|
||||
|
||||
+16
-17
@@ -432,27 +432,26 @@ async def push_to_explorer(
|
||||
# or if the guardrails check returned errors.
|
||||
guardrails_execution_result = guardrails_execution_result or {}
|
||||
guardrails_errors = guardrails_execution_result.get("errors", [])
|
||||
if guardrails_errors or not (
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_errors, action="block"
|
||||
)
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
if annotations or not (
|
||||
merged_response.get("choices")
|
||||
and merged_response["choices"][0].get("finish_reason")
|
||||
not in FINISH_REASON_TO_PUSH_TRACE
|
||||
):
|
||||
annotations = create_annotations_from_guardrails_errors(
|
||||
guardrails_errors, action="block"
|
||||
)
|
||||
|
||||
# Execute the logging guardrails before pushing to Explorer
|
||||
logging_guardrails_execution_result = await get_guardrails_check_result(
|
||||
context,
|
||||
action=GuardrailAction.LOG,
|
||||
response_json=merged_response,
|
||||
)
|
||||
logging_annotations = create_annotations_from_guardrails_errors(
|
||||
logging_guardrails_execution_result.get("errors", []), action="log"
|
||||
)
|
||||
# Update the annotations with the logging guardrails
|
||||
annotations.extend(logging_annotations)
|
||||
|
||||
# Combine the messages from the request body and the choices from the OpenAI response
|
||||
messages = list(context.request_json.get("messages", []))
|
||||
messages += [choice["message"] for choice in merged_response.get("choices", [])]
|
||||
|
||||
@@ -316,7 +316,10 @@ async def test_input_from_guardrail_from_file(
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
|
||||
)
|
||||
@pytest.mark.parametrize("do_stream", [True, False])
|
||||
async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream):
|
||||
"""Test that the guardrails from the explorer work."""
|
||||
@@ -461,3 +464,127 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
||||
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set"
)
@pytest.mark.parametrize(
    "do_stream, is_block_action",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
    explorer_api_url, gateway_url, do_stream, is_block_action
):
    """Test input (pre-request) guardrails that are configured on an Explorer dataset.

    A dataset is created and a guardrail is attached that flags any user
    message containing "pun", using either the "block" or the "log" action.
    A request containing "pun" is then sent through the gateway with the
    Anthropic client.

    With the "block" action the client call is expected to raise with the
    guardrail failure message. With the "log" action the call is expected to
    succeed. In both cases exactly one trace with exactly one annotation
    (carrying the matching guardrail action) is expected in the dataset.
    """
    dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}"
    client = get_anthropic_client(
        gateway_url, push_to_explorer=True, dataset_name=dataset_name
    )
    invariant_authorization = "Bearer " + os.getenv("INVARIANT_API_KEY")
    guardrail_action = "block" if is_block_action else "log"

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=dataset_name,
    )
    dataset_id = dataset_creation_response["id"]
    _ = await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
        action=guardrail_action,
        invariant_authorization=invariant_authorization,
    )

    user_prompt = "Tell me a one sentence pun."
    request = {
        "model": "claude-3-5-sonnet-20241022",
        "messages": [
            {
                "role": "user",
                "content": user_prompt,
            }
        ],
        "max_tokens": 100,
    }
    if is_block_action:
        if do_stream:
            with pytest.raises(APIStatusError) as exc_info:
                chat_response = client.messages.create(
                    **request,
                    stream=True,
                )
                # The stream has to be consumed inside the `raises` block so
                # that an error surfaced while iterating is captured.
                for _ in chat_response:
                    pass

            assert "[Invariant] The request did not pass the guardrails" in str(
                exc_info.value
            )
        else:
            with pytest.raises(BadRequestError) as exc_info:
                chat_response = client.messages.create(
                    **request,
                    stream=False,
                )

            assert exc_info.value.status_code == 400
            assert "[Invariant] The request did not pass the guardrails" in str(
                exc_info.value
            )
            assert "pun detected in user message" in str(exc_info.value)

    else:
        if do_stream:
            chat_response = client.messages.create(
                **request,
                stream=True,
            )
            # Drain the stream so the request runs to completion before the
            # trace is looked up (same as the Gemini variant of this test).
            for _ in chat_response:
                pass
        else:
            _ = client.messages.create(
                **request,
                stream=False,
            )

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace = trace_response.json()

    # A blocked request only holds the user message; otherwise the assistant
    # reply follows it. The expected count is computed first because
    # `assert x == 2 if cond else 1` would assert the constant `1` when the
    # condition is false.
    expected_message_count = 1 if is_block_action else 2
    assert len(trace["messages"]) == expected_message_count
    assert trace["messages"][0] == {
        "role": "user",
        "content": user_prompt,
    }
    if not is_block_action:
        assert trace["messages"][1].get("role") == "assistant"

    # Fetch annotations
    annotations_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
        timeout=5,
    )
    annotations = annotations_response.json()

    assert len(annotations) == 1
    annotation = annotations[0]
    # Separate asserts: each field is checked for both actions and a failure
    # points at the exact field that did not match.
    assert annotation["content"] == "pun detected in user message"
    assert annotation["extra_metadata"]["source"] == "guardrails-error"
    assert annotation["extra_metadata"]["guardrail-action"] == guardrail_action
|
||||
|
||||
@@ -444,6 +444,118 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set")
@pytest.mark.parametrize(
    "do_stream, is_block_action",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
    explorer_api_url, gateway_url, do_stream, is_block_action
):
    """Test input (pre-request) guardrails that are configured on an Explorer dataset.

    A dataset is created and a guardrail is attached that flags any user
    message containing "pun", using either the "block" or the "log" action.
    A request containing "pun" is then sent through the gateway with the
    Gemini client.

    With the "block" action a streamed request is expected to yield a refusal
    and a non-streamed request is expected to raise a client error. With the
    "log" action the call is expected to succeed. In both cases exactly one
    trace with exactly one annotation (carrying the matching guardrail action)
    is expected in the dataset.
    """
    dataset_name = f"test-dataset-gemini-{uuid.uuid4()}"
    client = get_gemini_client(
        gateway_url, push_to_explorer=True, dataset_name=dataset_name
    )
    invariant_authorization = "Bearer " + os.getenv("INVARIANT_API_KEY")
    guardrail_action = "block" if is_block_action else "log"

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=dataset_name,
    )
    dataset_id = dataset_creation_response["id"]
    _ = await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
        action=guardrail_action,
        invariant_authorization=invariant_authorization,
    )

    user_prompt = "Tell me a one sentence pun."
    request = {
        "model": "gemini-2.0-flash",
        "contents": user_prompt,
        "config": {
            "maxOutputTokens": 100,
        },
    }
    if is_block_action:
        if do_stream:
            chat_response = client.models.generate_content_stream(**request)

            assert_is_streamed_refusal(
                chat_response,
                [
                    "[Invariant] The request did not pass the guardrails",
                    "pun detected in user message",
                ],
            )
        else:
            with pytest.raises(genai.errors.ClientError) as exc_info:
                chat_response = client.models.generate_content(**request)
            assert "[Invariant] The request did not pass the guardrails" in str(
                exc_info.value
            )
            assert "pun detected in user message" in str(exc_info.value)
    else:
        if do_stream:
            response = client.models.generate_content_stream(**request)
            # Drain the stream so the request runs to completion.
            for _ in response:
                pass
        else:
            _ = client.models.generate_content(**request)

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace = trace_response.json()

    # A blocked request only holds the user message; otherwise the assistant
    # reply follows it. The expected count is computed first because
    # `assert x == 2 if cond else 1` would assert the constant `1` when the
    # condition is false.
    expected_message_count = 1 if is_block_action else 2
    assert len(trace["messages"]) == expected_message_count
    assert trace["messages"][0] == {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": user_prompt,
            }
        ],
    }
    if not is_block_action:
        assert trace["messages"][1].get("role") == "assistant"

    # Fetch annotations
    annotations_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
        timeout=5,
    )
    annotations = annotations_response.json()

    assert len(annotations) == 1
    annotation = annotations[0]
    # Separate asserts: each field is checked for both actions and a failure
    # points at the exact field that did not match.
    assert annotation["content"] == "pun detected in user message"
    assert annotation["extra_metadata"]["source"] == "guardrails-error"
    assert annotation["extra_metadata"]["guardrail-action"] == guardrail_action
|
||||
|
||||
|
||||
def is_refusal(chunk):
|
||||
return (
|
||||
len(chunk.candidates) == 1
|
||||
|
||||
@@ -412,15 +412,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
||||
**shrek_request,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for _ in chat_response:
|
||||
pass
|
||||
|
||||
assert "[Invariant] The response did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
# Only the block guardrail should be triggered here
|
||||
assert "ogre detected in response" in str(exc_info.value)
|
||||
assert "Fiona detected in response" not in str(exc_info.value)
|
||||
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
@@ -467,3 +464,124 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
|
||||
and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
|
||||
and annotations[1]["extra_metadata"]["guardrail-action"] == "log"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
@pytest.mark.parametrize(
    "do_stream, is_block_action",
    [(True, True), (True, False), (False, True), (False, False)],
)
async def test_preguardrailing_with_guardrails_from_explorer(
    explorer_api_url, gateway_url, do_stream, is_block_action
):
    """Test input (pre-request) guardrails that are configured on an Explorer dataset.

    A dataset is created and a guardrail is attached that flags any user
    message containing "pun", using either the "block" or the "log" action.
    A request containing "pun" is then sent through the gateway with the
    OpenAI client.

    With the "block" action the client call is expected to raise with the
    guardrail failure message. With the "log" action the call is expected to
    succeed. In both cases exactly one trace with exactly one annotation
    (carrying the matching guardrail action) is expected in the dataset.
    """
    dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
    client = get_open_ai_client(
        gateway_url, push_to_explorer=True, dataset_name=dataset_name
    )
    invariant_authorization = "Bearer " + os.getenv("INVARIANT_API_KEY")
    guardrail_action = "block" if is_block_action else "log"

    dataset_creation_response = await create_dataset(
        explorer_api_url,
        invariant_authorization=invariant_authorization,
        dataset_name=dataset_name,
    )
    dataset_id = dataset_creation_response["id"]
    _ = await add_guardrail_to_dataset(
        explorer_api_url,
        dataset_id=dataset_id,
        policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"',
        action=guardrail_action,
        invariant_authorization=invariant_authorization,
    )

    user_prompt = "Tell me a one sentence pun."
    request = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": user_prompt,
            }
        ],
        "max_tokens": 100,
    }
    if is_block_action:
        if do_stream:
            with pytest.raises(APIError) as exc_info:
                chat_response = client.chat.completions.create(
                    **request,
                    stream=True,
                )
                # The stream has to be consumed inside the `raises` block so
                # that an error surfaced while iterating is captured.
                for _ in chat_response:
                    pass

            assert "[Invariant] The request did not pass the guardrails" in str(
                exc_info.value
            )
        else:
            with pytest.raises(BadRequestError) as exc_info:
                chat_response = client.chat.completions.create(
                    **request,
                    stream=False,
                )

            assert exc_info.value.status_code == 400
            assert "[Invariant] The request did not pass the guardrails" in str(
                exc_info.value
            )
            assert "pun detected in user message" in str(exc_info.value)
    else:
        if do_stream:
            chat_response = client.chat.completions.create(
                **request,
                stream=True,
            )
            # Drain the stream so the request runs to completion before the
            # trace is looked up (same as the Gemini variant of this test).
            for _ in chat_response:
                pass
        else:
            _ = client.chat.completions.create(
                **request,
                stream=False,
            )

    # Wait for the trace to be saved
    # This is needed because the trace is saved asynchronously
    time.sleep(2)

    # Fetch the trace ids for the dataset
    traces_response = requests.get(
        f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
        timeout=5,
    )
    traces = traces_response.json()
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    # Fetch the trace
    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace = trace_response.json()

    # A blocked request only holds the user message; otherwise the assistant
    # reply follows it. The expected count is computed first because
    # `assert x == 1 if cond else 2` would assert the constant `2` when the
    # condition is false.
    expected_message_count = 1 if is_block_action else 2
    assert len(trace["messages"]) == expected_message_count
    assert trace["messages"][0] == {
        "role": "user",
        "content": user_prompt,
    }
    if not is_block_action:
        assert trace["messages"][1].get("role") == "assistant"

    # Fetch annotations
    annotations_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
        timeout=5,
    )
    annotations = annotations_response.json()

    assert len(annotations) == 1
    annotation = annotations[0]
    # Separate asserts: each field is checked for both actions and a failure
    # points at the exact field that did not match.
    assert annotation["content"] == "pun detected in user message"
    assert annotation["extra_metadata"]["source"] == "guardrails-error"
    assert annotation["extra_metadata"]["guardrail-action"] == guardrail_action
|
||||
|
||||
Reference in New Issue
Block a user