From 7f820bd79f5ec5565cf0fa2cd86fce19e23c9e22 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Fri, 28 Mar 2025 20:10:54 +0100 Subject: [PATCH] extend tests for input guardrailing --- gateway/routes/open_ai.py | 2 +- .../guardrails/test_guardrails_open_ai.py | 105 ++++++++++++++++++ .../guardrails/find_capital_guardrails.py | 8 +- 3 files changed, 113 insertions(+), 2 deletions(-) diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index e8fbf0f..629dc6b 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -500,7 +500,7 @@ class InstrumentedOpenAIResponse(InstrumentedResponse): Response( content=json.dumps( { - "error": "[Invariant] The response did not pass the guardrails", + "error": "[Invariant] The request did not pass the guardrails", "details": self.guardrails_execution_result, } ), diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index c013107..acc2f67 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -244,3 +244,108 @@ async def test_tool_call_guardrail_from_file( == "get_capital is called with Germany as argument" and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, push_to_explorer", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_input_from_guardrail_from_file( + explorer_api_url, gateway_url, do_stream, push_to_explorer +): + """Test the message content guardrail.""" + if not os.getenv("INVARIANT_API_KEY"): + pytest.fail("No INVARIANT_API_KEY set, failing") + + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + + client = OpenAI( + http_client=Client( + headers={ + "Invariant-Authorization": f"Bearer {os.getenv('INVARIANT_API_KEY')}" + }, + ), + base_url=f"{gateway_url}/api/v1/gateway/{dataset_name}/openai" + if push_to_explorer + else f"{gateway_url}/api/v1/gateway/openai", + ) + + request = { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Tell me more about Fight Club."}], + } + + if not do_stream: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "Users must not mention the magic phrase 'Fight Club'" in str( + exc_info.value + ) + + else: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=True, + ) + + for _ in chat_response: + pass + assert ( + "[Invariant] The request did not pass the guardrails" + in exc_info.value.message + ) + assert "Users must not mention the magic phrase 'Fight Club'" in str( + exc_info.value.body + ) + + if push_to_explorer: + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + # in case of input guardrailing, the pushed trace will not contain a response + assert len(trace["messages"]) == 1 + assert trace["messages"][0] == { + "role": "user", + "content": "Tell me more about Fight Club.", + } + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] + == "Users must not mention the magic phrase 'Fight Club'" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + ) diff --git a/tests/integration/resources/guardrails/find_capital_guardrails.py b/tests/integration/resources/guardrails/find_capital_guardrails.py index 8a9caf3..282720f 100644 --- a/tests/integration/resources/guardrails/find_capital_guardrails.py +++ b/tests/integration/resources/guardrails/find_capital_guardrails.py @@ -13,4 +13,10 @@ raise "Dublin detected in the response" if: raise "get_capital is called with Germany as argument" if: (call: ToolCall) call is tool:get_capital - call.function.arguments["country_name"] == "Germany" \ No newline at end of file + call.function.arguments["country_name"] == "Germany" + +# For input guardrailing specifically +raise "Users must not mention the magic phrase 'Fight Club'" if: + (msg: Message) + msg.role == "user" + "Fight Club" in msg.content \ No newline at end of file