From d56b406e1a4b22a878d4155bd69565ef61301e92 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Tue, 9 Dec 2025 20:00:04 +0200 Subject: [PATCH] fix(tests runtime): --- .../probe_data/modules/rl_model.py | 8 +++-- .../probe_data/modules/test_rl_model.py | 31 +++++++++++++------ tests/probe_actor/test_fuzzer.py | 9 +++++- tests/routes/test_probe.py | 4 ++- tests/test_lib.py | 24 ++++++++++++-- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py index 002158c..027e56c 100644 --- a/agentic_security/probe_data/modules/rl_model.py +++ b/agentic_security/probe_data/modules/rl_model.py @@ -206,7 +206,11 @@ class QLearningPromptSelector(PromptSelectionInterface): class Module: def __init__( - self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {} + self, + prompt_groups: list[str], + tools_inbox: asyncio.Queue, + opts: dict = {}, + rl_model: PromptSelectionInterface | None = None, ): self.tools_inbox = tools_inbox self.opts = opts @@ -214,7 +218,7 @@ class Module: self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts self.run_id = U.uuid4().hex self.batch_size = self.opts.get("batch_size", 500) - self.rl_model = CloudRLPromptSelector( + self.rl_model = rl_model or CloudRLPromptSelector( prompt_groups, "https://mcp.metaheuristic.co", run_id=self.run_id ) diff --git a/agentic_security/probe_data/modules/test_rl_model.py b/agentic_security/probe_data/modules/test_rl_model.py index 4f1aaf8..9cb848b 100644 --- a/agentic_security/probe_data/modules/test_rl_model.py +++ b/agentic_security/probe_data/modules/test_rl_model.py @@ -33,11 +33,17 @@ def mock_requests() -> Mock: @pytest.fixture -def mock_rl_selector() -> Mock: - return CloudRLPromptSelector( - dataset_prompts, - api_url="https://mcp.metaheuristic.co", - ) +def mock_rl_selector(dataset_prompts) -> Mock: + class StubSelector: + def __init__(self, prompts: list[str]): + self.prompts = prompts + self.idx = 0 + + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + self.idx = (self.idx + 1) % len(self.prompts) + return [self.prompts[self.idx]] + + return StubSelector(dataset_prompts) @pytest.fixture @@ -91,7 +97,10 @@ class TestCloudRLPromptSelector: next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) assert next_prompt in dataset_prompts - def test_select_next_prompt_success_service(self, dataset_prompts): + def test_select_next_prompt_success_service(self, dataset_prompts, mock_requests): + mock_requests.return_value.status_code = 200 + mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]} + selector = CloudRLPromptSelector( dataset_prompts, api_url="https://mcp.metaheuristic.co", @@ -99,7 +108,7 @@ class TestCloudRLPromptSelector: next_prompt = selector.select_next_prompt( "How does RL work?", passed_guard=True ) - assert next_prompt + assert next_prompt == "What is AI?" # Tests for QLearningPromptSelector @@ -188,7 +197,7 @@ class TestModule: async def test_apply_basic_flow( self, dataset_prompts, tools_inbox, mock_rl_selector ): - module = Module(dataset_prompts, tools_inbox) + module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector) count = 0 async for prompt in module.apply(): @@ -198,7 +207,9 @@ class TestModule: break @pytest.mark.asyncio - async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox): + async def test_apply_rl_with_tools_inbox( + self, dataset_prompts, tools_inbox, mock_rl_selector + ): # Add a test message to the tools inbox test_message = { "message": "Test message", @@ -207,7 +218,7 @@ class TestModule: } await tools_inbox.put(test_message) - module = Module(dataset_prompts, tools_inbox) + module = Module(dataset_prompts, tools_inbox, rl_model=mock_rl_selector) async for output in module.apply(): if output == "Test message": diff --git a/tests/probe_actor/test_fuzzer.py b/tests/probe_actor/test_fuzzer.py index 30cabba..24e6018 100644 --- a/tests/probe_actor/test_fuzzer.py +++ b/tests/probe_actor/test_fuzzer.py @@ -76,14 +76,21 @@ async def test_perform_single_shot_scan_success(prepare_prompts_mock): @pytest.mark.asyncio +@patch("agentic_security.probe_data.msj_data.prepare_prompts") @patch("agentic_security.probe_data.data.prepare_prompts") -async def test_perform_many_shot_scan_probe_injection(prepare_prompts_mock): +async def test_perform_many_shot_scan_probe_injection( + prepare_prompts_mock, msj_prepare_prompts_mock +): # Mock main and probe prompt modules prepare_prompts_mock.side_effect = [ [MagicMock(dataset_name="main_module", prompts=["main_prompt1"], lazy=False)], [MagicMock(dataset_name="probe_module", prompts=["probe_prompt1"], lazy=False)], ] + msj_prepare_prompts_mock.return_value = [ + MagicMock(dataset_name="msj_probe_module", prompts=["msj_probe_prompt"], lazy=False) + ] + # Mock request_factory mock_response = AsyncMock() mock_response.fn.side_effect = [ diff --git a/tests/routes/test_probe.py b/tests/routes/test_probe.py index 3525997..fd1fe5b 100644 --- a/tests/routes/test_probe.py +++ b/tests/routes/test_probe.py @@ -1,5 +1,6 @@ import base64 import io +import random import httpx import pytest @@ -85,8 +86,9 @@ def test_data_config_endpoint(): def test_refusal_rate(): """Test that refusal rate is approximately 20%""" + random.seed(0) refusal_count = 0 - total_trials = 1000 + total_trials = 200 for _ in range(total_trials): response = client.post("/v1/self-probe", json={"prompt": "test"}) diff --git a/tests/test_lib.py b/tests/test_lib.py index 20db36b..f341be7 100644 --- a/tests/test_lib.py +++ b/tests/test_lib.py @@ -1,6 +1,7 @@ import importlib import os import signal +import socket import subprocess import tempfile import time @@ -24,12 +25,29 @@ def test_server(request): preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN), ) - # Give the server time to start - time.sleep(2) + def wait_for_port(host: str, port: int, timeout: float = 5.0) -> bool: + start = time.time() + while time.time() - start < timeout: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + try: + sock.connect((host, port)) + return True + except OSError: + time.sleep(0.1) + return False + + if not wait_for_port("127.0.0.1", 9094): + server.kill() + pytest.skip("Test server failed to start within timeout") def cleanup(): server.terminate() - server.wait() + try: + server.wait(timeout=3) + except subprocess.TimeoutExpired: + server.kill() + server.wait(timeout=2) request.addfinalizer(cleanup) return server