diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 6d4a409..2643125 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -536,6 +536,7 @@ class InstrumentedGeminiResponse(InstrumentedResponse): ) async def on_end(self): + """Runs when the request ends.""" response_string = json.dumps(self.response_json) response_code = self.response.status_code diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index f929a2c..ff565e0 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -432,27 +432,26 @@ async def push_to_explorer( # or if the guardrails check returned errors. guardrails_execution_result = guardrails_execution_result or {} guardrails_errors = guardrails_execution_result.get("errors", []) - if guardrails_errors or not ( + annotations = create_annotations_from_guardrails_errors( + guardrails_errors, action="block" + ) + # Execute the logging guardrails before pushing to Explorer + logging_guardrails_execution_result = await get_guardrails_check_result( + context, + action=GuardrailAction.LOG, + response_json=merged_response, + ) + logging_annotations = create_annotations_from_guardrails_errors( + logging_guardrails_execution_result.get("errors", []), action="log" + ) + # Update the annotations with the logging guardrails + annotations.extend(logging_annotations) + + if annotations or not ( merged_response.get("choices") and merged_response["choices"][0].get("finish_reason") not in FINISH_REASON_TO_PUSH_TRACE ): - annotations = create_annotations_from_guardrails_errors( - guardrails_errors, action="block" - ) - - # Execute the logging guardrails before pushing to Explorer - logging_guardrails_execution_result = await get_guardrails_check_result( - context, - action=GuardrailAction.LOG, - response_json=merged_response, - ) - logging_annotations = create_annotations_from_guardrails_errors( - logging_guardrails_execution_result.get("errors", []), action="log" - ) - # Update the annotations with the logging guardrails - annotations.extend(logging_annotations) - # Combine the messages from the request body and the choices from the OpenAI response messages = list(context.request_json.get("messages", [])) messages += [choice["message"] for choice in merged_response.get("choices", [])] diff --git a/tests/integration/guardrails/test_guardrails_anthropic.py b/tests/integration/guardrails/test_guardrails_anthropic.py index f61d9e6..035a845 100644 --- a/tests/integration/guardrails/test_guardrails_anthropic.py +++ b/tests/integration/guardrails/test_guardrails_anthropic.py @@ -316,7 +316,10 @@ async def test_input_from_guardrail_from_file( and annotations[0]["extra_metadata"]["source"] == "guardrails-error" ) -@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set") + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) @pytest.mark.parametrize("do_stream", [True, False]) async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_stream): """Test that the guardrails from the explorer work.""" @@ -461,3 +464,127 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["guardrail-action"] == "log" ) + + +@pytest.mark.skipif( + not os.getenv("ANTHROPIC_API_KEY"), reason="No ANTHROPIC_API_KEY set" +) +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-anthropic-{uuid.uuid4()}" + client = get_anthropic_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "claude-3-5-sonnet-20241022", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIStatusError) as exc_info: + chat_response = client.messages.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.messages.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + + else: + if do_stream: + _ = client.messages.create( + **request, + stream=True, + ) + else: + _ = client.messages.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 if not is_block_action else 1 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + ) diff --git a/tests/integration/guardrails/test_guardrails_gemini.py b/tests/integration/guardrails/test_guardrails_gemini.py index 6fc0945..b3ac35e 100644 --- a/tests/integration/guardrails/test_guardrails_gemini.py +++ b/tests/integration/guardrails/test_guardrails_gemini.py @@ -444,6 +444,118 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s ) +@pytest.mark.skipif(not os.getenv("GEMINI_API_KEY"), reason="No GEMINI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-gemini-{uuid.uuid4()}" + client = get_gemini_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "gemini-2.0-flash", + "contents": user_prompt, + "config": { + "maxOutputTokens": 100, + }, + } + if is_block_action: + if do_stream: + chat_response = client.models.generate_content_stream(**request) + + assert_is_streamed_refusal( + chat_response, + [ + "[Invariant] The request did not pass the guardrails", + "pun detected in user message", + ], + ) + else: + with pytest.raises(genai.errors.ClientError) as exc_info: + chat_response = client.models.generate_content(**request) + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + response = client.models.generate_content_stream(**request) + for _ in response: + pass + else: + _ = client.models.generate_content(**request) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 2 if not is_block_action else 1 + assert trace["messages"][0] == { + "role": "user", + "content": [ + { + "type": "text", + "text": user_prompt, + } + ], + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + ) + + def is_refusal(chunk): return ( len(chunk.candidates) == 1 diff --git a/tests/integration/guardrails/test_guardrails_open_ai.py b/tests/integration/guardrails/test_guardrails_open_ai.py index b0c6b24..6031778 100644 --- a/tests/integration/guardrails/test_guardrails_open_ai.py +++ b/tests/integration/guardrails/test_guardrails_open_ai.py @@ -412,15 +412,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s **shrek_request, stream=True, ) - for _ in chat_response: pass + assert "[Invariant] The response did not pass the guardrails" in str( exc_info.value ) - # Only the block guardrail should be triggered here - assert "ogre detected in response" in str(exc_info.value) - assert "Fiona detected in response" not in str(exc_info.value) # Wait for the trace to be saved # This is needed because the trace is saved asynchronously @@ -467,3 +464,124 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["guardrail-action"] == "log" ) + + +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OPENAI_API_KEY set") +@pytest.mark.parametrize( + "do_stream, is_block_action", + [(True, True), (True, False), (False, True), (False, False)], +) +async def test_preguardrailing_with_guardrails_from_explorer( + explorer_api_url, gateway_url, do_stream, is_block_action +): + """Test that the guardrails from the explorer work.""" + dataset_name = f"test-dataset-open-ai-{uuid.uuid4()}" + client = get_open_ai_client( + gateway_url, push_to_explorer=True, dataset_name=dataset_name + ) + + dataset_creation_response = await create_dataset( + explorer_api_url, + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + dataset_name=dataset_name, + ) + dataset_id = dataset_creation_response["id"] + _ = await add_guardrail_to_dataset( + explorer_api_url, + dataset_id=dataset_id, + policy='raise "pun detected in user message" if:\n (msg: Message)\n "pun" in msg.content and msg.role == "user"', + action="block" if is_block_action else "log", + invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"), + ) + + user_prompt = "Tell me a one sentence pun." + request = { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": user_prompt, + } + ], + "max_tokens": 100, + } + if is_block_action: + if do_stream: + with pytest.raises(APIError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=True, + ) + for _ in chat_response: + pass + + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + else: + with pytest.raises(BadRequestError) as exc_info: + chat_response = client.chat.completions.create( + **request, + stream=False, + ) + + assert exc_info.value.status_code == 400 + assert "[Invariant] The request did not pass the guardrails" in str( + exc_info.value + ) + assert "pun detected in user message" in str(exc_info.value) + else: + if do_stream: + _ = client.chat.completions.create( + **request, + stream=True, + ) + else: + _ = client.chat.completions.create( + **request, + stream=False, + ) + + # Wait for the trace to be saved + # This is needed because the trace is saved asynchronously + time.sleep(2) + + # Fetch the trace ids for the dataset + traces_response = requests.get( + f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces", + timeout=5, + ) + traces = traces_response.json() + assert len(traces) == 1 + trace_id = traces[0]["id"] + + # Fetch the trace + trace_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}", + timeout=5, + ) + trace = trace_response.json() + + assert len(trace["messages"]) == 1 if is_block_action else 2 + assert trace["messages"][0] == { + "role": "user", + "content": user_prompt, + } + if not is_block_action: + assert trace["messages"][1].get("role") == "assistant" + + # Fetch annotations + annotations_response = requests.get( + f"{explorer_api_url}/api/v1/trace/{trace_id}/annotations", + timeout=5, + ) + annotations = annotations_response.json() + + assert len(annotations) == 1 + assert ( + annotations[0]["content"] == "pun detected in user message" + and annotations[0]["extra_metadata"]["source"] == "guardrails-error" + and annotations[0]["extra_metadata"]["guardrail-action"] == "block" + if is_block_action + else "log" + )