feat(Update rl_model tests):

This commit is contained in:
Alexander Myasoedov
2025-02-05 17:09:17 +02:00
parent b18427aa7e
commit 81ff6656e1
3 changed files with 66 additions and 62 deletions
+11 -19
View File
@@ -11,10 +11,8 @@ class PromptSelectionInterface(ABC):
"""Abstract base class for prompt selection strategies."""
@abstractmethod
def select_next_prompt(
self, current_prompt: str, model_response: str | None = None
) -> str:
"""Selects the next prompt based on current state and optional model response."""
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
"""Selects the next prompt based on current state and guard result."""
pass
@abstractmethod
@@ -23,7 +21,7 @@ class PromptSelectionInterface(ABC):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: str | None = None,
passed_guard: bool,
) -> None:
"""Updates internal rewards based on the outcome of the last selected prompt."""
pass
@@ -38,9 +36,7 @@ class RandomPromptSelector(PromptSelectionInterface):
self.prompts = prompts
self.history: Deque[str] = deque(maxlen=history_size)
def select_next_prompt(
self, current_prompt: str, model_response: str | None = None
) -> str:
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
available = [p for p in self.prompts if p not in self.history]
@@ -55,7 +51,7 @@ class RandomPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: str | None = None,
passed_guard: bool,
) -> None:
pass # No learning in random selection
@@ -79,9 +75,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
self.history: Deque[str] = deque(maxlen=history_size)
self.timeout = timeout
def select_next_prompt(
self, current_prompt: str, model_response: str | None = None
) -> str:
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
try:
@@ -89,7 +83,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
f"{self.api_url}/rl-select",
json={
"current_prompt": current_prompt,
"model_response": model_response,
"passed_guard": passed_guard,
"history": list(self.history),
},
headers=self.headers,
@@ -114,7 +108,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: str | None = None,
passed_guard: bool,
) -> None:
try:
requests.post(
@@ -123,7 +117,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
"previous_prompt": previous_prompt,
"current_prompt": current_prompt,
"reward": reward,
"model_response": model_response,
"passed_guard": passed_guard,
},
headers=self.headers,
timeout=self.timeout,
@@ -166,9 +160,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
for state in prompts
}
def select_next_prompt(
self, current_prompt: str, model_response: str | None = None
) -> str:
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
available = [a for a in self.prompts if a not in self.history]
@@ -194,7 +186,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: str | None = None,
passed_guard: bool,
) -> None:
if (
previous_prompt not in self.q_table
@@ -3,6 +3,7 @@ from unittest.mock import Mock, patch
import numpy as np
import pytest
import requests
# Import the classes to be tested
from agentic_security.probe_data.modules.rl_model import (
@@ -37,30 +38,16 @@ class TestRandomPromptSelector:
assert isinstance(selector.history, deque)
assert selector.history.maxlen == 3
def test_select_next_prompt_no_history(self, dataset_prompts):
def test_select_next_prompt(self, dataset_prompts):
selector = RandomPromptSelector(dataset_prompts)
current_prompt = "What is AI?"
next_prompt = selector.select_next_prompt(current_prompt)
assert next_prompt in dataset_prompts
assert next_prompt != current_prompt # Ensure no immediate repetition
def test_select_next_prompt_with_history(self, dataset_prompts):
selector = RandomPromptSelector(dataset_prompts)
selector.history.extend(["What is AI?", "How does RL work?"])
next_prompt = selector.select_next_prompt("Explain supervised learning.")
assert next_prompt not in selector.history
def test_select_next_prompt_reset_history(self, dataset_prompts):
selector = RandomPromptSelector(dataset_prompts, history_size=2)
selector.history.extend(["What is AI?", "How does RL work?"])
next_prompt = selector.select_next_prompt("Explain supervised learning.")
assert len(selector.history) == 2
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
assert next_prompt in dataset_prompts
assert next_prompt != current_prompt
def test_update_rewards_no_op(self, dataset_prompts):
selector = RandomPromptSelector(dataset_prompts)
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
# No state changes expected
selector.update_rewards("What is AI?", "How does RL work?", 1.0, True)
assert len(selector.history) == 0
@@ -77,16 +64,17 @@ class TestCloudRLPromptSelector:
mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"}
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
next_prompt = selector.select_next_prompt("How does RL work?")
next_prompt = selector.select_next_prompt(
"How does RL work?", passed_guard=True
)
assert next_prompt == "What is AI?"
mock_requests.assert_called_once()
def test_update_rewards_success(self, dataset_prompts, mock_requests):
mock_requests.return_value.status_code = 200
def test_fallback_on_failure(self, dataset_prompts, mock_requests):
mock_requests.side_effect = requests.exceptions.RequestException
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
mock_requests.assert_called_once()
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt in dataset_prompts
# Tests for QLearningPromptSelector
@@ -102,19 +90,19 @@ class TestQLearningPromptSelector:
def test_select_next_prompt_exploration(self, dataset_prompts):
selector = QLearningPromptSelector(dataset_prompts, initial_exploration=1.0)
next_prompt = selector.select_next_prompt("What is AI?")
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt in dataset_prompts
assert next_prompt != "What is AI?"
def test_select_next_prompt_exploitation(self, dataset_prompts):
selector = QLearningPromptSelector(dataset_prompts, initial_exploration=0.0)
selector.q_table["What is AI?"]["How does RL work?"] = 10.0 # Set high Q-value
next_prompt = selector.select_next_prompt("What is AI?")
selector.q_table["What is AI?"]["How does RL work?"] = 10.0
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt == "How does RL work?"
def test_update_rewards(self, dataset_prompts):
selector = QLearningPromptSelector(dataset_prompts)
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
selector.update_rewards("What is AI?", "How does RL work?", 1.0, True)
assert selector.q_table["What is AI?"]["How does RL work?"] > 0.0
def test_exploration_rate_decay(self, dataset_prompts):
@@ -122,21 +110,11 @@ class TestQLearningPromptSelector:
dataset_prompts, initial_exploration=1.0, exploration_decay=0.9
)
assert selector.exploration_rate == 1.0
selector.select_next_prompt("What is AI?")
selector.select_next_prompt("What is AI?", passed_guard=True)
assert selector.exploration_rate == 0.9
selector.select_next_prompt("How does RL work?")
selector.select_next_prompt("How does RL work?", passed_guard=True)
assert selector.exploration_rate == 0.81
def test_min_exploration_rate(self, dataset_prompts):
selector = QLearningPromptSelector(
dataset_prompts,
initial_exploration=0.1,
exploration_decay=0.5,
min_exploration=0.05,
)
selector.select_next_prompt("What is AI?")
assert selector.exploration_rate == 0.05 # Should not go below min_exploration
# Edge Cases and Error Handling
def test_empty_prompts():
@@ -146,10 +124,10 @@ def test_empty_prompts():
def test_cloud_rl_selector_invalid_url(dataset_prompts):
selector = CloudRLPromptSelector(dataset_prompts, "invalid_url", "token")
next_prompt = selector.select_next_prompt("What is AI?")
assert next_prompt in dataset_prompts # Should fallback to random selection
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt in dataset_prompts
def test_q_learning_selector_invalid_reward(dataset_prompts):
selector = QLearningPromptSelector(dataset_prompts)
selector.update_rewards("What is AI?", "How does RL work?", np.nan)
selector.update_rewards("What is AI?", "How does RL work?", np.nan, True)
+34
View File
@@ -0,0 +1,34 @@
import pytest
import requests
from agentic_security.probe_data import REGISTRY
@pytest.mark.parametrize("dataset", REGISTRY)
def test_registry_accessibility(dataset):
"""
Validate that datasets from REGISTRY are accessible.
- If it's a URL, check if the response status is 200.
- If it's a cloud-hosted dataset, skip the test.
"""
dataset_name = dataset.get("dataset_name", "Unknown Dataset")
dataset_url = dataset.get("url")
if not dataset_url:
pytest.fail(f"Dataset {dataset_name} is missing a URL.")
if dataset_url.lower() == "cloud":
pytest.skip(f"Skipping cloud dataset: {dataset_name}")
if isinstance(dataset_url, str) and dataset_url.startswith("http"):
try:
response = requests.head(
dataset_url, timeout=5
) # HEAD request for efficiency
assert (
response.status_code == 200
), f"Dataset URL is inaccessible: {dataset_url}"
except requests.exceptions.RequestException as e:
pytest.fail(f"Request failed for {dataset_name} ({dataset_url}): {e}")
else:
pytest.fail(f"Unexpected URL format for {dataset_name}: {dataset_url}")