diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py new file mode 100644 index 0000000..44a0286 --- /dev/null +++ b/agentic_security/probe_data/modules/rl_model.py @@ -0,0 +1,210 @@ +from abc import ABC, abstractmethod +import numpy as np +import random +import requests +from collections import deque +from typing import List, Optional, Deque, Dict + + +class PromptSelectionInterface(ABC): + """Abstract base class for prompt selection strategies.""" + + @abstractmethod + def select_next_prompt( + self, current_prompt: str, model_response: Optional[str] = None + ) -> str: + """Selects the next prompt based on current state and optional model response.""" + pass + + @abstractmethod + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + model_response: Optional[str] = None, + ) -> None: + """Updates internal rewards based on the outcome of the last selected prompt.""" + pass + + +class RandomPromptSelector(PromptSelectionInterface): + """Random prompt selector with cycle prevention using history.""" + + def __init__(self, prompts: List[str], history_size: int = 3): + if not prompts: + raise ValueError("Prompts list cannot be empty") + self.prompts = prompts + self.history: Deque[str] = deque(maxlen=history_size) + + def select_next_prompt( + self, current_prompt: str, model_response: Optional[str] = None + ) -> str: + self.history.append(current_prompt) + available = [p for p in self.prompts if p not in self.history] + + if not available: + available = self.prompts + self.history.clear() + + return random.choice(available) + + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + model_response: Optional[str] = None, + ) -> None: + pass # No learning in random selection + + +class CloudRLPromptSelector(PromptSelectionInterface): + """Cloud-based reinforcement learning prompt selector with fallback.""" + + def __init__( + self, + prompts: List[str], + api_url: str, + auth_token: str, + history_size: int = 3, + timeout: int = 5, + ): + if not prompts: + raise ValueError("Prompts list cannot be empty") + self.prompts = prompts + self.api_url = api_url + self.headers = {"Authorization": f"Bearer {auth_token}"} + self.history: Deque[str] = deque(maxlen=history_size) + self.timeout = timeout + + def select_next_prompt( + self, current_prompt: str, model_response: Optional[str] = None + ) -> str: + self.history.append(current_prompt) + + try: + response = requests.post( + f"{self.api_url}/rl-select", + json={ + "current_prompt": current_prompt, + "model_response": model_response, + "history": list(self.history), + }, + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + return response.json().get("next_prompt", self._fallback_selection()) + except requests.exceptions.RequestException as e: + print(f"Cloud request failed: {e}") + return self._fallback_selection() + + def _fallback_selection(self) -> str: + """Fallback to random selection when cloud service fails.""" + available = [p for p in self.prompts if p not in self.history] + if not available: + available = self.prompts + self.history.clear() + return random.choice(available) + + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + model_response: Optional[str] = None, + ) -> None: + try: + requests.post( + f"{self.api_url}/rl-update", + json={ + "previous_prompt": previous_prompt, + "current_prompt": current_prompt, + "reward": reward, + "model_response": model_response, + }, + headers=self.headers, + timeout=self.timeout, + ) + except requests.exceptions.RequestException as e: + print(f"Reward update failed: {e}") + + +class QLearningPromptSelector(PromptSelectionInterface): + """Q-Learning based prompt selector with exploration/exploitation tradeoff.""" + + def __init__( + self, + prompts: List[str], + learning_rate: float = 0.1, + discount_factor: float = 0.9, + initial_exploration: float = 1.0, + exploration_decay: float = 0.995, + min_exploration: float = 0.01, + history_size: int = 3, + ): + if not prompts: + raise ValueError("Prompts list cannot be empty") + + self.prompts = prompts + self.learning_rate = learning_rate + self.discount_factor = discount_factor + self.exploration_rate = initial_exploration + self.exploration_decay = exploration_decay + self.min_exploration = min_exploration + self.history: Deque[str] = deque(maxlen=history_size) + + # Initialize Q-table with small random values + self.q_table: Dict[str, Dict[str, float]] = { + state: { + action: np.random.uniform(0, 0.1) + for action in prompts + if action != state + } + for state in prompts + } + + def select_next_prompt( + self, current_prompt: str, model_response: Optional[str] = None + ) -> str: + self.history.append(current_prompt) + available = [a for a in self.prompts if a not in self.history] + + if not available: + available = self.prompts + self.history.clear() + + # Exploration-exploitation tradeoff + if np.random.random() < self.exploration_rate: + selected = random.choice(available) + else: + q_values = {a: self.q_table[current_prompt][a] for a in available} + selected = max(q_values, key=q_values.get) # type: ignore + + # Decay exploration rate + self.exploration_rate = max( + self.min_exploration, self.exploration_rate * self.exploration_decay + ) + return selected + + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + model_response: Optional[str] = None, + ) -> None: + if ( + previous_prompt not in self.q_table + or current_prompt not in self.q_table[previous_prompt] + ): + return + + # Calculate temporal difference error + max_future_q = max(self.q_table[current_prompt].values(), default=0.0) + td_target = reward + self.discount_factor * max_future_q + td_error = td_target - self.q_table[previous_prompt][current_prompt] + + # Update Q-value + self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error diff --git a/agentic_security/probe_data/modules/test_rl_model.py b/agentic_security/probe_data/modules/test_rl_model.py new file mode 100644 index 0000000..c79804a --- /dev/null +++ b/agentic_security/probe_data/modules/test_rl_model.py @@ -0,0 +1,156 @@ +import pytest +from collections import deque +from typing import List, Optional +from unittest.mock import patch, Mock +import numpy as np + +# Import the classes to be tested +from agentic_security.probe_data.modules.rl_model import ( + PromptSelectionInterface, + RandomPromptSelector, + CloudRLPromptSelector, + QLearningPromptSelector, +) + + +# Fixtures for reusable test data +@pytest.fixture +def dataset_prompts() -> List[str]: + return [ + "What is AI?", + "How does RL work?", + "Explain supervised learning.", + "What is reinforcement learning?", + ] + + +@pytest.fixture +def mock_requests() -> Mock: + with patch("requests.post") as mock_requests: + yield mock_requests + + +# Tests for RandomPromptSelector +class TestRandomPromptSelector: + def test_initialization(self, dataset_prompts): + selector = RandomPromptSelector(dataset_prompts) + assert selector.prompts == dataset_prompts + assert isinstance(selector.history, deque) + assert selector.history.maxlen == 3 + + def test_select_next_prompt_no_history(self, dataset_prompts): + selector = RandomPromptSelector(dataset_prompts) + current_prompt = "What is AI?" + next_prompt = selector.select_next_prompt(current_prompt) + assert next_prompt in dataset_prompts + assert next_prompt != current_prompt # Ensure no immediate repetition + + def test_select_next_prompt_with_history(self, dataset_prompts): + selector = RandomPromptSelector(dataset_prompts) + selector.history.extend(["What is AI?", "How does RL work?"]) + next_prompt = selector.select_next_prompt("Explain supervised learning.") + assert next_prompt not in selector.history + + def test_select_next_prompt_reset_history(self, dataset_prompts): + selector = RandomPromptSelector(dataset_prompts, history_size=2) + selector.history.extend(["What is AI?", "How does RL work?"]) + next_prompt = selector.select_next_prompt("Explain supervised learning.") + assert len(selector.history) == 2 + assert next_prompt in dataset_prompts + + def test_update_rewards_no_op(self, dataset_prompts): + selector = RandomPromptSelector(dataset_prompts) + selector.update_rewards("What is AI?", "How does RL work?", 1.0) + # No state changes expected + assert len(selector.history) == 0 + + +# Tests for CloudRLPromptSelector +class TestCloudRLPromptSelector: + def test_initialization(self, dataset_prompts): + selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") + assert selector.prompts == dataset_prompts + assert selector.api_url == "http://example.com" + assert selector.headers == {"Authorization": "Bearer token"} + + def test_select_next_prompt_success(self, dataset_prompts, mock_requests): + mock_requests.return_value.status_code = 200 + mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"} + + selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") + next_prompt = selector.select_next_prompt("How does RL work?") + assert next_prompt == "What is AI?" + mock_requests.assert_called_once() + + def test_update_rewards_success(self, dataset_prompts, mock_requests): + mock_requests.return_value.status_code = 200 + + selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") + selector.update_rewards("What is AI?", "How does RL work?", 1.0) + mock_requests.assert_called_once() + + +# Tests for QLearningPromptSelector +class TestQLearningPromptSelector: + def test_initialization(self, dataset_prompts): + selector = QLearningPromptSelector(dataset_prompts) + assert selector.prompts == dataset_prompts + assert selector.exploration_rate == 1.0 + assert len(selector.q_table) == len(dataset_prompts) + assert all( + len(v) == len(dataset_prompts) - 1 for v in selector.q_table.values() + ) + + def test_select_next_prompt_exploration(self, dataset_prompts): + selector = QLearningPromptSelector(dataset_prompts, initial_exploration=1.0) + next_prompt = selector.select_next_prompt("What is AI?") + assert next_prompt in dataset_prompts + assert next_prompt != "What is AI?" + + def test_select_next_prompt_exploitation(self, dataset_prompts): + selector = QLearningPromptSelector(dataset_prompts, initial_exploration=0.0) + selector.q_table["What is AI?"]["How does RL work?"] = 10.0 # Set high Q-value + next_prompt = selector.select_next_prompt("What is AI?") + assert next_prompt == "How does RL work?" + + def test_update_rewards(self, dataset_prompts): + selector = QLearningPromptSelector(dataset_prompts) + selector.update_rewards("What is AI?", "How does RL work?", 1.0) + assert selector.q_table["What is AI?"]["How does RL work?"] > 0.0 + + def test_exploration_rate_decay(self, dataset_prompts): + selector = QLearningPromptSelector( + dataset_prompts, initial_exploration=1.0, exploration_decay=0.9 + ) + assert selector.exploration_rate == 1.0 + selector.select_next_prompt("What is AI?") + assert selector.exploration_rate == 0.9 + selector.select_next_prompt("How does RL work?") + assert selector.exploration_rate == 0.81 + + def test_min_exploration_rate(self, dataset_prompts): + selector = QLearningPromptSelector( + dataset_prompts, + initial_exploration=0.1, + exploration_decay=0.5, + min_exploration=0.05, + ) + selector.select_next_prompt("What is AI?") + assert selector.exploration_rate == 0.05 # Should not go below min_exploration + + +# Edge Cases and Error Handling +def test_empty_prompts(): + with pytest.raises(ValueError, match="Prompts list cannot be empty"): + RandomPromptSelector([]) + + +def test_cloud_rl_selector_invalid_url(dataset_prompts): + selector = CloudRLPromptSelector(dataset_prompts, "invalid_url", "token") + next_prompt = selector.select_next_prompt("What is AI?") + assert next_prompt in dataset_prompts # Should fallback to random selection + + +def test_q_learning_selector_invalid_reward(dataset_prompts): + selector = QLearningPromptSelector(dataset_prompts) + selector.update_rewards("What is AI?", "How does RL work?", np.nan)