mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(add reinforcement_learning module):
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import random
|
||||
import requests
|
||||
from collections import deque
|
||||
from typing import List, Optional, Deque, Dict
|
||||
|
||||
|
||||
class PromptSelectionInterface(ABC):
|
||||
"""Abstract base class for prompt selection strategies."""
|
||||
|
||||
@abstractmethod
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
) -> str:
|
||||
"""Selects the next prompt based on current state and optional model response."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Updates internal rewards based on the outcome of the last selected prompt."""
|
||||
pass
|
||||
|
||||
|
||||
class RandomPromptSelector(PromptSelectionInterface):
|
||||
"""Random prompt selector with cycle prevention using history."""
|
||||
|
||||
def __init__(self, prompts: List[str], history_size: int = 3):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
self.prompts = prompts
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [p for p in self.prompts if p not in self.history]
|
||||
|
||||
if not available:
|
||||
available = self.prompts
|
||||
self.history.clear()
|
||||
|
||||
return random.choice(available)
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
) -> None:
|
||||
pass # No learning in random selection
|
||||
|
||||
|
||||
class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
"""Cloud-based reinforcement learning prompt selector with fallback."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompts: List[str],
|
||||
api_url: str,
|
||||
auth_token: str,
|
||||
history_size: int = 3,
|
||||
timeout: int = 5,
|
||||
):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
self.prompts = prompts
|
||||
self.api_url = api_url
|
||||
self.headers = {"Authorization": f"Bearer {auth_token}"}
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
self.timeout = timeout
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_url}/rl-select",
|
||||
json={
|
||||
"current_prompt": current_prompt,
|
||||
"model_response": model_response,
|
||||
"history": list(self.history),
|
||||
},
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get("next_prompt", self._fallback_selection())
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Cloud request failed: {e}")
|
||||
return self._fallback_selection()
|
||||
|
||||
def _fallback_selection(self) -> str:
|
||||
"""Fallback to random selection when cloud service fails."""
|
||||
available = [p for p in self.prompts if p not in self.history]
|
||||
if not available:
|
||||
available = self.prompts
|
||||
self.history.clear()
|
||||
return random.choice(available)
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{self.api_url}/rl-update",
|
||||
json={
|
||||
"previous_prompt": previous_prompt,
|
||||
"current_prompt": current_prompt,
|
||||
"reward": reward,
|
||||
"model_response": model_response,
|
||||
},
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Reward update failed: {e}")
|
||||
|
||||
|
||||
class QLearningPromptSelector(PromptSelectionInterface):
|
||||
"""Q-Learning based prompt selector with exploration/exploitation tradeoff."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompts: List[str],
|
||||
learning_rate: float = 0.1,
|
||||
discount_factor: float = 0.9,
|
||||
initial_exploration: float = 1.0,
|
||||
exploration_decay: float = 0.995,
|
||||
min_exploration: float = 0.01,
|
||||
history_size: int = 3,
|
||||
):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
|
||||
self.prompts = prompts
|
||||
self.learning_rate = learning_rate
|
||||
self.discount_factor = discount_factor
|
||||
self.exploration_rate = initial_exploration
|
||||
self.exploration_decay = exploration_decay
|
||||
self.min_exploration = min_exploration
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
|
||||
# Initialize Q-table with small random values
|
||||
self.q_table: Dict[str, Dict[str, float]] = {
|
||||
state: {
|
||||
action: np.random.uniform(0, 0.1)
|
||||
for action in prompts
|
||||
if action != state
|
||||
}
|
||||
for state in prompts
|
||||
}
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [a for a in self.prompts if a not in self.history]
|
||||
|
||||
if not available:
|
||||
available = self.prompts
|
||||
self.history.clear()
|
||||
|
||||
# Exploration-exploitation tradeoff
|
||||
if np.random.random() < self.exploration_rate:
|
||||
selected = random.choice(available)
|
||||
else:
|
||||
q_values = {a: self.q_table[current_prompt][a] for a in available}
|
||||
selected = max(q_values, key=q_values.get) # type: ignore
|
||||
|
||||
# Decay exploration rate
|
||||
self.exploration_rate = max(
|
||||
self.min_exploration, self.exploration_rate * self.exploration_decay
|
||||
)
|
||||
return selected
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
) -> None:
|
||||
if (
|
||||
previous_prompt not in self.q_table
|
||||
or current_prompt not in self.q_table[previous_prompt]
|
||||
):
|
||||
return
|
||||
|
||||
# Calculate temporal difference error
|
||||
max_future_q = max(self.q_table[current_prompt].values(), default=0.0)
|
||||
td_target = reward + self.discount_factor * max_future_q
|
||||
td_error = td_target - self.q_table[previous_prompt][current_prompt]
|
||||
|
||||
# Update Q-value
|
||||
self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error
|
||||
@@ -0,0 +1,156 @@
|
||||
import pytest
|
||||
from collections import deque
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch, Mock
|
||||
import numpy as np
|
||||
|
||||
# Import the classes to be tested
|
||||
from agentic_security.probe_data.modules.rl_model import (
|
||||
PromptSelectionInterface,
|
||||
RandomPromptSelector,
|
||||
CloudRLPromptSelector,
|
||||
QLearningPromptSelector,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures for reusable test data
|
||||
@pytest.fixture
|
||||
def dataset_prompts() -> List[str]:
|
||||
return [
|
||||
"What is AI?",
|
||||
"How does RL work?",
|
||||
"Explain supervised learning.",
|
||||
"What is reinforcement learning?",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests() -> Mock:
|
||||
with patch("requests.post") as mock_requests:
|
||||
yield mock_requests
|
||||
|
||||
|
||||
# Tests for RandomPromptSelector
|
||||
class TestRandomPromptSelector:
|
||||
def test_initialization(self, dataset_prompts):
|
||||
selector = RandomPromptSelector(dataset_prompts)
|
||||
assert selector.prompts == dataset_prompts
|
||||
assert isinstance(selector.history, deque)
|
||||
assert selector.history.maxlen == 3
|
||||
|
||||
def test_select_next_prompt_no_history(self, dataset_prompts):
|
||||
selector = RandomPromptSelector(dataset_prompts)
|
||||
current_prompt = "What is AI?"
|
||||
next_prompt = selector.select_next_prompt(current_prompt)
|
||||
assert next_prompt in dataset_prompts
|
||||
assert next_prompt != current_prompt # Ensure no immediate repetition
|
||||
|
||||
def test_select_next_prompt_with_history(self, dataset_prompts):
|
||||
selector = RandomPromptSelector(dataset_prompts)
|
||||
selector.history.extend(["What is AI?", "How does RL work?"])
|
||||
next_prompt = selector.select_next_prompt("Explain supervised learning.")
|
||||
assert next_prompt not in selector.history
|
||||
|
||||
def test_select_next_prompt_reset_history(self, dataset_prompts):
|
||||
selector = RandomPromptSelector(dataset_prompts, history_size=2)
|
||||
selector.history.extend(["What is AI?", "How does RL work?"])
|
||||
next_prompt = selector.select_next_prompt("Explain supervised learning.")
|
||||
assert len(selector.history) == 2
|
||||
assert next_prompt in dataset_prompts
|
||||
|
||||
def test_update_rewards_no_op(self, dataset_prompts):
|
||||
selector = RandomPromptSelector(dataset_prompts)
|
||||
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
|
||||
# No state changes expected
|
||||
assert len(selector.history) == 0
|
||||
|
||||
|
||||
# Tests for CloudRLPromptSelector
|
||||
class TestCloudRLPromptSelector:
|
||||
def test_initialization(self, dataset_prompts):
|
||||
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
|
||||
assert selector.prompts == dataset_prompts
|
||||
assert selector.api_url == "http://example.com"
|
||||
assert selector.headers == {"Authorization": "Bearer token"}
|
||||
|
||||
def test_select_next_prompt_success(self, dataset_prompts, mock_requests):
|
||||
mock_requests.return_value.status_code = 200
|
||||
mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"}
|
||||
|
||||
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
|
||||
next_prompt = selector.select_next_prompt("How does RL work?")
|
||||
assert next_prompt == "What is AI?"
|
||||
mock_requests.assert_called_once()
|
||||
|
||||
def test_update_rewards_success(self, dataset_prompts, mock_requests):
|
||||
mock_requests.return_value.status_code = 200
|
||||
|
||||
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
|
||||
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
|
||||
mock_requests.assert_called_once()
|
||||
|
||||
|
||||
# Tests for QLearningPromptSelector
|
||||
class TestQLearningPromptSelector:
|
||||
def test_initialization(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts)
|
||||
assert selector.prompts == dataset_prompts
|
||||
assert selector.exploration_rate == 1.0
|
||||
assert len(selector.q_table) == len(dataset_prompts)
|
||||
assert all(
|
||||
len(v) == len(dataset_prompts) - 1 for v in selector.q_table.values()
|
||||
)
|
||||
|
||||
def test_select_next_prompt_exploration(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts, initial_exploration=1.0)
|
||||
next_prompt = selector.select_next_prompt("What is AI?")
|
||||
assert next_prompt in dataset_prompts
|
||||
assert next_prompt != "What is AI?"
|
||||
|
||||
def test_select_next_prompt_exploitation(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts, initial_exploration=0.0)
|
||||
selector.q_table["What is AI?"]["How does RL work?"] = 10.0 # Set high Q-value
|
||||
next_prompt = selector.select_next_prompt("What is AI?")
|
||||
assert next_prompt == "How does RL work?"
|
||||
|
||||
def test_update_rewards(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts)
|
||||
selector.update_rewards("What is AI?", "How does RL work?", 1.0)
|
||||
assert selector.q_table["What is AI?"]["How does RL work?"] > 0.0
|
||||
|
||||
def test_exploration_rate_decay(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(
|
||||
dataset_prompts, initial_exploration=1.0, exploration_decay=0.9
|
||||
)
|
||||
assert selector.exploration_rate == 1.0
|
||||
selector.select_next_prompt("What is AI?")
|
||||
assert selector.exploration_rate == 0.9
|
||||
selector.select_next_prompt("How does RL work?")
|
||||
assert selector.exploration_rate == 0.81
|
||||
|
||||
def test_min_exploration_rate(self, dataset_prompts):
|
||||
selector = QLearningPromptSelector(
|
||||
dataset_prompts,
|
||||
initial_exploration=0.1,
|
||||
exploration_decay=0.5,
|
||||
min_exploration=0.05,
|
||||
)
|
||||
selector.select_next_prompt("What is AI?")
|
||||
assert selector.exploration_rate == 0.05 # Should not go below min_exploration
|
||||
|
||||
|
||||
# Edge Cases and Error Handling
|
||||
def test_empty_prompts():
|
||||
with pytest.raises(ValueError, match="Prompts list cannot be empty"):
|
||||
RandomPromptSelector([])
|
||||
|
||||
|
||||
def test_cloud_rl_selector_invalid_url(dataset_prompts):
|
||||
selector = CloudRLPromptSelector(dataset_prompts, "invalid_url", "token")
|
||||
next_prompt = selector.select_next_prompt("What is AI?")
|
||||
assert next_prompt in dataset_prompts # Should fallback to random selection
|
||||
|
||||
|
||||
def test_q_learning_selector_invalid_reward(dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts)
|
||||
selector.update_rewards("What is AI?", "How does RL work?", np.nan)
|
||||
Reference in New Issue
Block a user