mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
feat(add \Reinforcement Learning Optimization doc):
This commit is contained in:
@@ -50,6 +50,50 @@ The `probe_data` module is a core component of the Agentic Security project, res
|
||||
- `base64_encode(data)`: Encodes data in base64 format.
|
||||
- `mirror_words(text)`: Mirrors words in the text.
|
||||
|
||||
### rl_model.py
|
||||
|
||||
- **Classes:**
|
||||
- `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
|
||||
|
||||
- Methods:
|
||||
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects next prompt
|
||||
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
|
||||
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
|
||||
|
||||
- `RandomPromptSelector`: Basic random selection with history tracking.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `history_size: int = 3`: Size of history to prevent cycles
|
||||
|
||||
- `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `api_url: str`: URL of RL service
|
||||
- `auth_token: str = AUTH_TOKEN`: Authentication token
|
||||
- `history_size: int = 300`: Size of history
|
||||
- `timeout: int = 5`: Request timeout
|
||||
- `run_id: str = ""`: Unique run identifier
|
||||
|
||||
- `QLearningPromptSelector`: Local Q-learning implementation.
|
||||
|
||||
- Parameters:
|
||||
- `prompts: list[str]`: List of available prompts
|
||||
- `learning_rate: float = 0.1`: Learning rate
|
||||
- `discount_factor: float = 0.9`: Discount factor
|
||||
- `initial_exploration: float = 1.0`: Initial exploration rate
|
||||
- `exploration_decay: float = 0.995`: Exploration decay rate
|
||||
- `min_exploration: float = 0.01`: Minimum exploration rate
|
||||
- `history_size: int = 300`: Size of history
|
||||
|
||||
- `Module`: Main class that uses CloudRLPromptSelector.
|
||||
|
||||
- Parameters:
|
||||
- `prompt_groups: list[str]`: Groups of prompts
|
||||
- `tools_inbox: asyncio.Queue`: Queue for tool communication
|
||||
- `opts: dict = {}`: Configuration options
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Generating Audio
|
||||
@@ -68,6 +112,19 @@ from agentic_security.probe_data.data import load_dataset_general
|
||||
dataset = load_dataset_general("example_dataset")
|
||||
```
|
||||
|
||||
### Using RL Model
|
||||
|
||||
```python
|
||||
from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
|
||||
|
||||
prompts = ["What is AI?", "Explain machine learning"]
|
||||
selector = QLearningPromptSelector(prompts)
|
||||
|
||||
current_prompt = "What is AI?"
|
||||
next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
|
||||
selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
The `probe_data` module provides essential functionality for handling and transforming datasets within the Agentic Security project. This documentation serves as a guide to understanding and utilizing the module's capabilities.
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
# RL Model Module
|
||||
|
||||
The RL Model module provides reinforcement learning-based prompt selection strategies for the probe system.
|
||||
|
||||
## Overview
|
||||
|
||||
The module implements several prompt selection strategies that use reinforcement learning techniques to optimize prompt selection based on guard results and rewards.
|
||||
|
||||
## Classes
|
||||
|
||||
### PromptSelectionInterface
|
||||
|
||||
Abstract base class defining the interface for prompt selection strategies.
|
||||
|
||||
**Methods:**
|
||||
|
||||
- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`
|
||||
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`
|
||||
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`
|
||||
|
||||
### RandomPromptSelector
|
||||
|
||||
Basic random selection strategy with cycle prevention using history.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `prompts`: List of available prompts
|
||||
- `history_size`: Size of history buffer to prevent cycles (default: 300)
|
||||
|
||||
### CloudRLPromptSelector
|
||||
|
||||
Cloud-based reinforcement learning prompt selector with fallback to random selection.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `prompts`: List of available prompts
|
||||
- `api_url`: URL of the RL service
|
||||
- `auth_token`: Authentication token (default: AS_TOKEN environment variable)
|
||||
- `history_size`: Size of history buffer (default: 300)
|
||||
- `timeout`: Request timeout in seconds (default: 5)
|
||||
- `run_id`: Unique identifier for the run
|
||||
|
||||
### QLearningPromptSelector
|
||||
|
||||
Q-Learning based prompt selector with exploration/exploitation tradeoff.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `prompts`: List of available prompts
|
||||
- `learning_rate`: Learning rate (default: 0.1)
|
||||
- `discount_factor`: Discount factor (default: 0.9)
|
||||
- `initial_exploration`: Initial exploration rate (default: 1.0)
|
||||
- `exploration_decay`: Exploration decay rate (default: 0.995)
|
||||
- `min_exploration`: Minimum exploration rate (default: 0.01)
|
||||
- `history_size`: Size of history buffer (default: 300)
|
||||
|
||||
### Module
|
||||
|
||||
Main class that implements the RL-based prompt selection functionality.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
- `prompt_groups`: List of prompt groups
|
||||
- `tools_inbox`: asyncio.Queue for tool communication
|
||||
- `opts`: Additional options
|
||||
- `max_prompts`: Maximum number of prompts to generate (default: 10)
|
||||
- `batch_size`: Batch size for processing (default: 500)
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
from agentic_security.probe_data.modules.rl_model import (
|
||||
Module,
|
||||
CloudRLPromptSelector,
|
||||
QLearningPromptSelector
|
||||
)
|
||||
|
||||
# Initialize with prompt groups
|
||||
prompt_groups = ["What is AI?", "Explain ML", "Describe RL"]
|
||||
module = Module(prompt_groups, asyncio.Queue())
|
||||
|
||||
# Use the module
|
||||
async for prompt in module.apply():
|
||||
print(f"Selected prompt: {prompt}")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### PromptSelectionInterface
|
||||
|
||||
```python
|
||||
class PromptSelectionInterface(ABC):
|
||||
@abstractmethod
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
"""Select next prompt based on current state and guard result."""
|
||||
|
||||
@abstractmethod
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
"""Select next prompts based on current state and guard result."""
|
||||
|
||||
@abstractmethod
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
"""Update internal rewards based on outcome of last selected prompt."""
|
||||
```
|
||||
|
||||
### RandomPromptSelector
|
||||
|
||||
```python
|
||||
class RandomPromptSelector(PromptSelectionInterface):
|
||||
def __init__(self, prompts: list[str], history_size: int = 300):
|
||||
"""Initialize with prompts and history size."""
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
"""Select next prompt randomly with cycle prevention."""
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
"""No learning in random selection."""
|
||||
```
|
||||
|
||||
### CloudRLPromptSelector
|
||||
|
||||
```python
|
||||
class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
prompts: list[str],
|
||||
api_url: str,
|
||||
auth_token: str = AUTH_TOKEN,
|
||||
history_size: int = 300,
|
||||
timeout: int = 5,
|
||||
run_id: str = "",
|
||||
):
|
||||
"""Initialize with cloud RL configuration."""
|
||||
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
"""Select next prompts using cloud RL with fallback."""
|
||||
|
||||
def _fallback_selection(self) -> str:
|
||||
"""Fallback to random selection if cloud request fails."""
|
||||
```
|
||||
|
||||
### QLearningPromptSelector
|
||||
|
||||
```python
|
||||
class QLearningPromptSelector(PromptSelectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
prompts: list[str],
|
||||
learning_rate: float = 0.1,
|
||||
discount_factor: float = 0.9,
|
||||
initial_exploration: float = 1.0,
|
||||
exploration_decay: float = 0.995,
|
||||
min_exploration: float = 0.01,
|
||||
history_size: int = 300,
|
||||
):
|
||||
"""Initialize Q-Learning configuration."""
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
"""Select next prompt using Q-Learning with exploration/exploitation."""
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
"""Update Q-values based on reward."""
|
||||
```
|
||||
|
||||
### Module
|
||||
|
||||
```python
|
||||
class Module:
|
||||
def __init__(
|
||||
self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
|
||||
):
|
||||
"""Initialize module with prompt groups and configuration."""
|
||||
|
||||
async def apply(self):
|
||||
"""Apply the RL model to generate prompts."""
|
||||
```
|
||||
Reference in New Issue
Block a user