From eb27f7bbaa8939f4a711c8aa8f00736ff80f4bb5 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Fri, 7 Feb 2025 01:02:12 +0200 Subject: [PATCH] feat(add \Reinforcement Learning Optimization doc): --- .../probe_data/modules/rl_model.py | 5 +- docs/probe_data.md | 57 +++++ docs/rl_model.md | 194 ++++++++++++++++++ mkdocs.yml | 1 + 4 files changed, 255 insertions(+), 2 deletions(-) create mode 100644 docs/rl_model.md diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py index 5b424cc..e1b71ed 100644 --- a/agentic_security/probe_data/modules/rl_model.py +++ b/agentic_security/probe_data/modules/rl_model.py @@ -41,7 +41,7 @@ class PromptSelectionInterface(ABC): class RandomPromptSelector(PromptSelectionInterface): """Random prompt selector with cycle prevention using history.""" - def __init__(self, prompts: list[str], history_size: int = 3): + def __init__(self, prompts: list[str], history_size: int = 300): if not prompts: raise ValueError("Prompts list cannot be empty") self.prompts = prompts @@ -120,7 +120,8 @@ class CloudRLPromptSelector(PromptSelectionInterface): current_prompt: str, reward: float, passed_guard: bool, - ) -> None: ... + ) -> None: + ... class QLearningPromptSelector(PromptSelectionInterface): diff --git a/docs/probe_data.md b/docs/probe_data.md index dda6648..4a3b6c2 100644 --- a/docs/probe_data.md +++ b/docs/probe_data.md @@ -50,6 +50,50 @@ The `probe_data` module is a core component of the Agentic Security project, res - `base64_encode(data)`: Encodes data in base64 format. - `mirror_words(text)`: Mirrors words in the text. +### rl_model.py + +- **Classes:** + - `PromptSelectionInterface`: Abstract base class for prompt selection strategies. 
+ + - Methods: + - `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt + - `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts + - `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards + + - `RandomPromptSelector`: Basic random selection with history tracking. + + - Parameters: + - `prompts: list[str]`: List of available prompts + - `history_size: int = 300`: Size of history to prevent cycles + + - `CloudRLPromptSelector`: Cloud-based RL implementation with fallback. + + - Parameters: + - `prompts: list[str]`: List of available prompts + - `api_url: str`: URL of RL service + - `auth_token: str = AUTH_TOKEN`: Authentication token + - `history_size: int = 300`: Size of history + - `timeout: int = 5`: Request timeout + - `run_id: str = ""`: Unique run identifier + + - `QLearningPromptSelector`: Local Q-learning implementation. + + - Parameters: + - `prompts: list[str]`: List of available prompts + - `learning_rate: float = 0.1`: Learning rate + - `discount_factor: float = 0.9`: Discount factor + - `initial_exploration: float = 1.0`: Initial exploration rate + - `exploration_decay: float = 0.995`: Exploration decay rate + - `min_exploration: float = 0.01`: Minimum exploration rate + - `history_size: int = 300`: Size of history + + - `Module`: Main class that uses CloudRLPromptSelector. 
+ + - Parameters: + - `prompt_groups: list[str]`: Groups of prompts + - `tools_inbox: asyncio.Queue`: Queue for tool communication + - `opts: dict = {}`: Configuration options + ## Usage Examples ### Generating Audio @@ -68,6 +112,19 @@ from agentic_security.probe_data.data import load_dataset_general dataset = load_dataset_general("example_dataset") ``` +### Using RL Model + +```python +from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector + +prompts = ["What is AI?", "Explain machine learning"] +selector = QLearningPromptSelector(prompts) + +current_prompt = "What is AI?" +next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True) +selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True) +``` + ## Conclusion The `probe_data` module provides essential functionality for handling and transforming datasets within the Agentic Security project. This documentation serves as a guide to understanding and utilizing the module's capabilities. diff --git a/docs/rl_model.md b/docs/rl_model.md new file mode 100644 index 0000000..efc0b86 --- /dev/null +++ b/docs/rl_model.md @@ -0,0 +1,194 @@ +# RL Model Module + +The RL Model module provides reinforcement learning-based prompt selection strategies for the probe system. + +## Overview + +The module implements several prompt selection strategies that use reinforcement learning techniques to optimize prompt selection based on guard results and rewards. + +## Classes + +### PromptSelectionInterface + +Abstract base class defining the interface for prompt selection strategies. + +**Methods:** + +- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str` +- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]` +- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None` + +### RandomPromptSelector + +Basic random selection strategy with cycle prevention using history. 
+ +**Configuration:** + +- `prompts`: List of available prompts +- `history_size`: Size of history buffer to prevent cycles (default: 300) + +### CloudRLPromptSelector + +Cloud-based reinforcement learning prompt selector with fallback to random selection. + +**Configuration:** + +- `prompts`: List of available prompts +- `api_url`: URL of the RL service +- `auth_token`: Authentication token (default: AS_TOKEN environment variable) +- `history_size`: Size of history buffer (default: 300) +- `timeout`: Request timeout in seconds (default: 5) +- `run_id`: Unique identifier for the run + +### QLearningPromptSelector + +Q-Learning based prompt selector with exploration/exploitation tradeoff. + +**Configuration:** + +- `prompts`: List of available prompts +- `learning_rate`: Learning rate (default: 0.1) +- `discount_factor`: Discount factor (default: 0.9) +- `initial_exploration`: Initial exploration rate (default: 1.0) +- `exploration_decay`: Exploration decay rate (default: 0.995) +- `min_exploration`: Minimum exploration rate (default: 0.01) +- `history_size`: Size of history buffer (default: 300) + +### Module + +Main class that implements the RL-based prompt selection functionality. 
+ +**Configuration:** + +- `prompt_groups`: List of prompt groups +- `tools_inbox`: asyncio.Queue for tool communication +- `opts`: Additional options + - `max_prompts`: Maximum number of prompts to generate (default: 10) + - `batch_size`: Batch size for processing (default: 500) + +## Usage Example + +```python +from agentic_security.probe_data.modules.rl_model import ( + Module, + CloudRLPromptSelector, + QLearningPromptSelector +) + +# Initialize with prompt groups +prompt_groups = ["What is AI?", "Explain ML", "Describe RL"] +module = Module(prompt_groups, asyncio.Queue()) + +# Use the module +async for prompt in module.apply(): + print(f"Selected prompt: {prompt}") +``` + +## API Reference + +### PromptSelectionInterface + +```python +class PromptSelectionInterface(ABC): + @abstractmethod + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: + """Select next prompt based on current state and guard result.""" + + @abstractmethod + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + """Select next prompts based on current state and guard result.""" + + @abstractmethod + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + passed_guard: bool, + ) -> None: + """Update internal rewards based on outcome of last selected prompt.""" +``` + +### RandomPromptSelector + +```python +class RandomPromptSelector(PromptSelectionInterface): + def __init__(self, prompts: list[str], history_size: int = 300): + """Initialize with prompts and history size.""" + + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: + """Select next prompt randomly with cycle prevention.""" + + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + passed_guard: bool, + ) -> None: + """No learning in random selection.""" +``` + +### CloudRLPromptSelector + +```python +class CloudRLPromptSelector(PromptSelectionInterface): + def 
__init__( + self, + prompts: list[str], + api_url: str, + auth_token: str = AUTH_TOKEN, + history_size: int = 300, + timeout: int = 5, + run_id: str = "", + ): + """Initialize with cloud RL configuration.""" + + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + """Select next prompts using cloud RL with fallback.""" + + def _fallback_selection(self) -> str: + """Fallback to random selection if cloud request fails.""" +``` + +### QLearningPromptSelector + +```python +class QLearningPromptSelector(PromptSelectionInterface): + def __init__( + self, + prompts: list[str], + learning_rate: float = 0.1, + discount_factor: float = 0.9, + initial_exploration: float = 1.0, + exploration_decay: float = 0.995, + min_exploration: float = 0.01, + history_size: int = 300, + ): + """Initialize Q-Learning configuration.""" + + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: + """Select next prompt using Q-Learning with exploration/exploitation.""" + + def update_rewards( + self, + previous_prompt: str, + current_prompt: str, + reward: float, + passed_guard: bool, + ) -> None: + """Update Q-values based on reward.""" +``` + +### Module + +```python +class Module: + def __init__( + self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {} + ): + """Initialize module with prompt groups and configuration.""" + + async def apply(self): + """Apply the RL model to generate prompts.""" +``` diff --git a/mkdocs.yml b/mkdocs.yml index 1464d98..9810b67 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,6 +25,7 @@ nav: - Bayesian Optimization: optimizer.md - Image Generation: image_generation.md - Stenography Functions: stenography.md + - Reinforcement Learning Optimization: rl_model.md - Reference: - API Reference: api_reference.md - Community: