mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-23 15:29:43 +02:00
extend tests for input guardrailing
This commit is contained in:
committed by
Hemang
parent
2a66582c7c
commit
7f820bd79f
@@ -500,7 +500,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse):
|
||||
Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"error": "[Invariant] The response did not pass the guardrails",
|
||||
"error": "[Invariant] The request did not pass the guardrails",
|
||||
"details": self.guardrails_execution_result,
|
||||
}
|
||||
),
|
||||
|
||||
@@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file(
|
||||
== "get_capital is called with Germany as argument"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set")
|
||||
@pytest.mark.parametrize(
|
||||
"do_stream, push_to_explorer",
|
||||
[(True, True), (True, False), (False, True), (False, False)],
|
||||
)
|
||||
async def test_input_from_guardrail_from_file(
|
||||
explorer_api_url, gateway_url, do_stream, push_to_explorer
|
||||
):
|
||||
"""Test the message content guardrail."""
|
||||
if not os.getenv("INVARIANT_API_KEY"):
|
||||
pytest.fail("No INVARIANT_API_KEY set, failing")
|
||||
|
||||
dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}"
|
||||
|
||||
client = OpenAI(
|
||||
http_client=Client(
|
||||
headers={
|
||||
"Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}"
|
||||
},
|
||||
),
|
||||
base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai"
|
||||
if push_to_explorer
|
||||
else f"{gateway_url}/api/v1/gateway/openai",
|
||||
)
|
||||
|
||||
request = {
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Tell me more about Fight Club."}],
|
||||
}
|
||||
|
||||
if not do_stream:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**request,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "[Invariant] The request did not pass the guardrails" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
else:
|
||||
with pytest.raises(APIError) as exc_info:
|
||||
chat_response = client.chat.completions.create(
|
||||
**request,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for _ in chat_response:
|
||||
pass
|
||||
assert (
|
||||
"[Invariant] The request did not pass the guardrails"
|
||||
in exc_info.value.message
|
||||
)
|
||||
assert "Users must not mention the magic phrase 'Fight Club'" in str(
|
||||
exc_info.value.body
|
||||
)
|
||||
|
||||
if push_to_explorer:
|
||||
# Wait for the trace to be saved
|
||||
# This is needed because the trace is saved asynchronously
|
||||
time.sleep(2)
|
||||
|
||||
# Fetch the trace ids for the dataset
|
||||
traces_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces",
|
||||
timeout=5,
|
||||
)
|
||||
traces = traces_response.json()
|
||||
assert len(traces) == 1
|
||||
trace_id = traces[0]["id"]
|
||||
|
||||
# Fetch the trace
|
||||
trace_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}",
|
||||
timeout=5,
|
||||
)
|
||||
trace = trace_response.json()
|
||||
|
||||
# in case of input guardrailing, the pushed trace will not contain a response
|
||||
assert len(trace["messages"]) == 1
|
||||
assert trace["messages"][0] == {
|
||||
"role": "user",
|
||||
"content": "Tell me more about Fight Club.",
|
||||
}
|
||||
|
||||
# Fetch annotations
|
||||
annotations_response = requests.get(
|
||||
f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations",
|
||||
timeout=5,
|
||||
)
|
||||
annotations = annotations_response.json()
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert (
|
||||
annotations[0]["content"]
|
||||
== "Users must not mention the magic phrase 'Fight Club'"
|
||||
and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
|
||||
)
|
||||
|
||||
@@ -13,4 +13,10 @@ raise "Dublin detected in the response" if:
|
||||
raise "get_capital is called with Germany as argument" if:
|
||||
(call: ToolCall)
|
||||
call is tool:get_capital
|
||||
call.function.arguments["country_name"] == "Germany"
|
||||
call.function.arguments["country_name"] == "Germany"
|
||||
|
||||
# For input guardrailing specifically
|
||||
raise "Users must not mention the magic phrase 'Fight Club'" if:
|
||||
(msg: Message)
|
||||
msg.role == "user"
|
||||
"Fight Club" in msg.content
|
||||
Reference in New Issue
Block a user