From 81ff6656e1218ca86689760f1582ffe094aa61ed Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Wed, 5 Feb 2025 17:09:17 +0200 Subject: [PATCH] feat(Update rl_model tests): --- .../probe_data/modules/rl_model.py | 30 ++++----- .../probe_data/modules/test_rl_model.py | 64 ++++++------------- agentic_security/test_registry.py | 34 ++++++++++ 3 files changed, 66 insertions(+), 62 deletions(-) create mode 100644 agentic_security/test_registry.py diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py index b0f0159..f6e4dc2 100644 --- a/agentic_security/probe_data/modules/rl_model.py +++ b/agentic_security/probe_data/modules/rl_model.py @@ -11,10 +11,8 @@ class PromptSelectionInterface(ABC): """Abstract base class for prompt selection strategies.""" @abstractmethod - def select_next_prompt( - self, current_prompt: str, model_response: str | None = None - ) -> str: - """Selects the next prompt based on current state and optional model response.""" + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: + """Selects the next prompt based on current state and guard result.""" pass @abstractmethod @@ -23,7 +21,7 @@ class PromptSelectionInterface(ABC): previous_prompt: str, current_prompt: str, reward: float, - model_response: str | None = None, + passed_guard: bool, ) -> None: """Updates internal rewards based on the outcome of the last selected prompt.""" pass @@ -38,9 +36,7 @@ class RandomPromptSelector(PromptSelectionInterface): self.prompts = prompts self.history: Deque[str] = deque(maxlen=history_size) - def select_next_prompt( - self, current_prompt: str, model_response: str | None = None - ) -> str: + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: self.history.append(current_prompt) available = [p for p in self.prompts if p not in self.history] @@ -55,7 +51,7 @@ class RandomPromptSelector(PromptSelectionInterface): previous_prompt: str, current_prompt: str, reward: float, - model_response: str | None = None, + passed_guard: bool, ) -> None: pass # No learning in random selection @@ -79,9 +75,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): self.history: Deque[str] = deque(maxlen=history_size) self.timeout = timeout - def select_next_prompt( - self, current_prompt: str, model_response: str | None = None - ) -> str: + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: self.history.append(current_prompt) try: @@ -89,7 +83,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): f"{self.api_url}/rl-select", json={ "current_prompt": current_prompt, - "model_response": model_response, + "passed_guard": passed_guard, "history": list(self.history), }, headers=self.headers, @@ -114,7 +108,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): previous_prompt: str, current_prompt: str, reward: float, - model_response: str | None = None, + passed_guard: bool, ) -> None: try: requests.post( @@ -123,7 +117,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): "previous_prompt": previous_prompt, "current_prompt": current_prompt, "reward": reward, - "model_response": model_response, + "passed_guard": passed_guard, }, headers=self.headers, timeout=self.timeout, @@ -166,9 +160,7 @@ class QLearningPromptSelector(PromptSelectionInterface): for state in prompts } - def select_next_prompt( - self, current_prompt: str, model_response: str | None = None - ) -> str: + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: self.history.append(current_prompt) available = [a for a in self.prompts if a not in self.history] @@ -194,7 +186,7 @@ class QLearningPromptSelector(PromptSelectionInterface): previous_prompt: str, current_prompt: str, reward: float, - model_response: str | None = None, + passed_guard: bool, ) -> None: if ( previous_prompt not in self.q_table diff --git a/agentic_security/probe_data/modules/test_rl_model.py b/agentic_security/probe_data/modules/test_rl_model.py index 3663cea..bf8bc7c 100644 --- a/agentic_security/probe_data/modules/test_rl_model.py +++ b/agentic_security/probe_data/modules/test_rl_model.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import numpy as np import pytest +import requests # Import the classes to be tested from agentic_security.probe_data.modules.rl_model import ( @@ -37,30 +38,16 @@ class TestRandomPromptSelector: assert isinstance(selector.history, deque) assert selector.history.maxlen == 3 - def test_select_next_prompt_no_history(self, dataset_prompts): + def test_select_next_prompt(self, dataset_prompts): selector = RandomPromptSelector(dataset_prompts) current_prompt = "What is AI?" - next_prompt = selector.select_next_prompt(current_prompt) - assert next_prompt in dataset_prompts - assert next_prompt != current_prompt # Ensure no immediate repetition - - def test_select_next_prompt_with_history(self, dataset_prompts): - selector = RandomPromptSelector(dataset_prompts) - selector.history.extend(["What is AI?", "How does RL work?"]) - next_prompt = selector.select_next_prompt("Explain supervised learning.") - assert next_prompt not in selector.history - - def test_select_next_prompt_reset_history(self, dataset_prompts): - selector = RandomPromptSelector(dataset_prompts, history_size=2) - selector.history.extend(["What is AI?", "How does RL work?"]) - next_prompt = selector.select_next_prompt("Explain supervised learning.") - assert len(selector.history) == 2 + next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True) assert next_prompt in dataset_prompts + assert next_prompt != current_prompt def test_update_rewards_no_op(self, dataset_prompts): selector = RandomPromptSelector(dataset_prompts) - selector.update_rewards("What is AI?", "How does RL work?", 1.0) - # No state changes expected + selector.update_rewards("What is AI?", "How does RL work?", 1.0, True) assert len(selector.history) == 0 @@ -77,16 +64,17 @@ class TestCloudRLPromptSelector: mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"} selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") - next_prompt = selector.select_next_prompt("How does RL work?") + next_prompt = selector.select_next_prompt( + "How does RL work?", passed_guard=True + ) assert next_prompt == "What is AI?" mock_requests.assert_called_once() - def test_update_rewards_success(self, dataset_prompts, mock_requests): - mock_requests.return_value.status_code = 200 - + def test_fallback_on_failure(self, dataset_prompts, mock_requests): + mock_requests.side_effect = requests.exceptions.RequestException selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") - selector.update_rewards("What is AI?", "How does RL work?", 1.0) - mock_requests.assert_called_once() + next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) + assert next_prompt in dataset_prompts # Tests for QLearningPromptSelector @@ -102,19 +90,19 @@ class TestQLearningPromptSelector: def test_select_next_prompt_exploration(self, dataset_prompts): selector = QLearningPromptSelector(dataset_prompts, initial_exploration=1.0) - next_prompt = selector.select_next_prompt("What is AI?") + next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) assert next_prompt in dataset_prompts assert next_prompt != "What is AI?" def test_select_next_prompt_exploitation(self, dataset_prompts): selector = QLearningPromptSelector(dataset_prompts, initial_exploration=0.0) - selector.q_table["What is AI?"]["How does RL work?"] = 10.0 # Set high Q-value - next_prompt = selector.select_next_prompt("What is AI?") + selector.q_table["What is AI?"]["How does RL work?"] = 10.0 + next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) assert next_prompt == "How does RL work?" def test_update_rewards(self, dataset_prompts): selector = QLearningPromptSelector(dataset_prompts) - selector.update_rewards("What is AI?", "How does RL work?", 1.0) + selector.update_rewards("What is AI?", "How does RL work?", 1.0, True) assert selector.q_table["What is AI?"]["How does RL work?"] > 0.0 def test_exploration_rate_decay(self, dataset_prompts): @@ -122,21 +110,11 @@ class TestQLearningPromptSelector: dataset_prompts, initial_exploration=1.0, exploration_decay=0.9 ) assert selector.exploration_rate == 1.0 - selector.select_next_prompt("What is AI?") + selector.select_next_prompt("What is AI?", passed_guard=True) assert selector.exploration_rate == 0.9 - selector.select_next_prompt("How does RL work?") + selector.select_next_prompt("How does RL work?", passed_guard=True) assert selector.exploration_rate == 0.81 - def test_min_exploration_rate(self, dataset_prompts): - selector = QLearningPromptSelector( - dataset_prompts, - initial_exploration=0.1, - exploration_decay=0.5, - min_exploration=0.05, - ) - selector.select_next_prompt("What is AI?") - assert selector.exploration_rate == 0.05 # Should not go below min_exploration - # Edge Cases and Error Handling def test_empty_prompts(): @@ -146,10 +124,10 @@ def test_empty_prompts(): def test_cloud_rl_selector_invalid_url(dataset_prompts): selector = CloudRLPromptSelector(dataset_prompts, "invalid_url", "token") - next_prompt = selector.select_next_prompt("What is AI?") - assert next_prompt in dataset_prompts # Should fallback to random selection + next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) + assert next_prompt in dataset_prompts def test_q_learning_selector_invalid_reward(dataset_prompts): selector = QLearningPromptSelector(dataset_prompts) - selector.update_rewards("What is AI?", "How does RL work?", np.nan) + selector.update_rewards("What is AI?", "How does RL work?", np.nan, True) diff --git a/agentic_security/test_registry.py b/agentic_security/test_registry.py new file mode 100644 index 0000000..91739a1 --- /dev/null +++ b/agentic_security/test_registry.py @@ -0,0 +1,34 @@ +import pytest +import requests +from agentic_security.probe_data import REGISTRY + + +@pytest.mark.parametrize("dataset", REGISTRY) +def test_registry_accessibility(dataset): + """ + Validate that datasets from REGISTRY are accessible. + - If it's a URL, check if the response status is 200. + - If it's a cloud-hosted dataset, skip the test. + """ + dataset_name = dataset.get("dataset_name", "Unknown Dataset") + dataset_url = dataset.get("url") + + if not dataset_url: + pytest.fail(f"Dataset {dataset_name} is missing a URL.") + + if dataset_url.lower() == "cloud": + pytest.skip(f"Skipping cloud dataset: {dataset_name}") + + if isinstance(dataset_url, str) and dataset_url.startswith("http"): + try: + response = requests.head( + dataset_url, timeout=5 + ) # HEAD request for efficiency + assert ( + response.status_code == 200 + ), f"Dataset URL is inaccessible: {dataset_url}" + except requests.exceptions.RequestException as e: + pytest.fail(f"Request failed for {dataset_name} ({dataset_url}): {e}") + + else: + pytest.fail(f"Unexpected URL format for {dataset_name}: {dataset_url}")