fix(rl_model.Module):

This commit is contained in:
Alexander Myasoedov
2025-02-07 00:54:10 +02:00
parent 21c37b823d
commit e0eed6fd92
5 changed files with 142 additions and 36 deletions
+15
View File
@@ -408,6 +408,21 @@ REGISTRY = REGISTRY_V0 + [
},
"modality": "text",
},
{
"dataset_name": "Reinforcement Learning Optimization",
"num_prompts": 0,
"tokens": 0,
"approx_cost": 0.0,
"source": "Cloud hosted model",
"selected": False,
"url": "",
"dynamic": True,
"opts": {
"port": 8718,
"modules": ["encoding"],
},
"modality": "text",
},
{
"dataset_name": "InspectAI",
"num_prompts": 0,
+6
View File
@@ -16,6 +16,7 @@ from agentic_security.probe_data.modules import (
fine_tuned,
garak_tool,
inspect_ai_tool,
rl_model,
)
@@ -265,6 +266,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
"Reinforcement Learning Optimization",
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
lazy=True,
),
"InspectAI": lambda opts: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
@@ -1,5 +1,7 @@
import asyncio
import os
import random
import uuid as U
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque
@@ -78,6 +80,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
auth_token: str = AUTH_TOKEN,
history_size: int = 300,
timeout: int = 5,
run_id: str = "",
):
if not prompts:
raise ValueError("Prompts list cannot be empty")
@@ -85,6 +88,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
self.api_url = api_url
self.headers = {"Authorization": f"Bearer {auth_token}"}
self.timeout = timeout
self.run_id = run_id or U.uuid4().hex
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]:
return self.select_next_prompts(current_prompt, passed_guard)[0]
@@ -94,6 +98,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
response = requests.post(
f"{self.api_url}/rl-model/select-next-prompt",
json={
"run_id": U.uuid4().hex,
"current_prompt": current_prompt,
"passed_guard": passed_guard,
},
@@ -115,8 +120,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
current_prompt: str,
reward: float,
passed_guard: bool,
) -> None:
...
) -> None: ...
class QLearningPromptSelector(PromptSelectionInterface):
@@ -197,3 +201,46 @@ class QLearningPromptSelector(PromptSelectionInterface):
# Update Q-value
self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error
class Module:
def __init__(
self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
):
self.tools_inbox = tools_inbox
self.opts = opts
self.prompt_groups = prompt_groups
self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts
self.run_id = U.uuid4().hex
self.batch_size = self.opts.get("batch_size", 500)
self.rl_model = CloudRLPromptSelector(
prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
)
async def apply(self):
current_prompt = "What is AI?"
passed_guard = False
for _ in range(max(self.max_prompts, 1)):
# Fetch prompts from the API
prompts = await asyncio.to_thread(
lambda: self.rl_model.select_next_prompts(
current_prompt, passed_guard=passed_guard
)
)
if not prompts:
logger.error("No prompts retrieved from the API.")
return
logger.info(f"Retrieved {len(prompts)} prompts.")
for i, prompt in enumerate(prompts):
logger.info(f"Processing prompt {i+1}/{len(prompts)}: {prompt}")
yield prompt
current_prompt = prompt
while not self.tools_inbox.empty():
ref = await self.tools_inbox.get()
print(ref, "ref")
message, _, ready = ref["message"], ref["reply"], ref["ready"]
yield message
ready.set()
@@ -1,3 +1,4 @@
import asyncio
from collections import deque
from unittest.mock import Mock, patch
@@ -8,6 +9,7 @@ import requests
# Import the classes to be tested
from agentic_security.probe_data.modules.rl_model import (
CloudRLPromptSelector,
Module,
QLearningPromptSelector,
RandomPromptSelector,
)
@@ -30,6 +32,19 @@ def mock_requests() -> Mock:
yield mock_requests
@pytest.fixture
def mock_rl_selector() -> Mock:
return CloudRLPromptSelector(
dataset_prompts,
api_url="https://edge.metaheuristic.co",
)
@pytest.fixture
def tools_inbox() -> asyncio.Queue:
return asyncio.Queue()
# Tests for RandomPromptSelector
class TestRandomPromptSelector:
def test_initialization(self, dataset_prompts):
@@ -141,3 +156,60 @@ def test_cloud_rl_selector_invalid_url(dataset_prompts):
def test_q_learning_selector_invalid_reward(dataset_prompts):
selector = QLearningPromptSelector(dataset_prompts)
selector.update_rewards("What is AI?", "How does RL work?", np.nan, True)
# Tests for Module class
class TestModule:
@pytest.fixture
def mock_uuid(self):
with patch("uuid.uuid4") as mock:
mock.return_value.hex = "test_run_id"
yield mock
def test_initialization(self, dataset_prompts, tools_inbox, mock_uuid):
module = Module(dataset_prompts, tools_inbox)
assert module.prompt_groups == dataset_prompts
assert module.tools_inbox == tools_inbox
assert module.max_prompts == 2000
assert module.batch_size == 500
assert module.run_id == "test_run_id"
assert isinstance(module.rl_model, CloudRLPromptSelector)
def test_initialization_with_options(self, dataset_prompts, tools_inbox, mock_uuid):
opts = {
"max_prompts": 100,
"batch_size": 50,
}
module = Module(dataset_prompts, tools_inbox, opts)
assert module.max_prompts == 100
assert module.batch_size == 50
@pytest.mark.asyncio
async def test_apply_basic_flow(
self, dataset_prompts, tools_inbox, mock_rl_selector
):
module = Module(dataset_prompts, tools_inbox)
count = 0
async for prompt in module.apply():
assert prompt == "Test prompt"
count += 1
if count >= 3: # Test a few iterations
break
@pytest.mark.asyncio
async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox):
# Add a test message to the tools inbox
test_message = {
"message": "Test message",
"reply": None,
"ready": asyncio.Event(),
}
await tools_inbox.put(test_message)
module = Module(dataset_prompts, tools_inbox)
async for output in module.apply():
if output == "Test message":
test_message["ready"].set()
break
-34
View File
@@ -1,34 +0,0 @@
import pytest
import requests
from agentic_security.probe_data import REGISTRY
@pytest.mark.parametrize("dataset", REGISTRY)
def test_registry_accessibility(dataset):
"""
Validate that datasets from REGISTRY are accessible.
- If it's a URL, check if the response status is 200.
- If it's a cloud-hosted dataset, skip the test.
"""
dataset_name = dataset.get("dataset_name", "Unknown Dataset")
dataset_url = dataset.get("url")
if not dataset_url:
pytest.fail(f"Dataset {dataset_name} is missing a URL.")
if dataset_url.lower() == "cloud":
pytest.skip(f"Skipping cloud dataset: {dataset_name}")
if isinstance(dataset_url, str) and dataset_url.startswith("http"):
try:
response = requests.head(
dataset_url, timeout=5
) # HEAD request for efficiency
assert (
response.status_code == 200
), f"Dataset URL is inaccessible: {dataset_url}"
except requests.exceptions.RequestException as e:
pytest.fail(f"Request failed for {dataset_name} ({dataset_url}): {e}")
else:
pytest.fail(f"Unexpected URL format for {dataset_name}: {dataset_url}")