mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(rl_model.Module):
This commit is contained in:
@@ -408,6 +408,21 @@ REGISTRY = REGISTRY_V0 + [
|
||||
},
|
||||
"modality": "text",
|
||||
},
|
||||
{
|
||||
"dataset_name": "Reinforcement Learning Optimization",
|
||||
"num_prompts": 0,
|
||||
"tokens": 0,
|
||||
"approx_cost": 0.0,
|
||||
"source": "Cloud hosted model",
|
||||
"selected": False,
|
||||
"url": "",
|
||||
"dynamic": True,
|
||||
"opts": {
|
||||
"port": 8718,
|
||||
"modules": ["encoding"],
|
||||
},
|
||||
"modality": "text",
|
||||
},
|
||||
{
|
||||
"dataset_name": "InspectAI",
|
||||
"num_prompts": 0,
|
||||
|
||||
@@ -16,6 +16,7 @@ from agentic_security.probe_data.modules import (
|
||||
fine_tuned,
|
||||
garak_tool,
|
||||
inspect_ai_tool,
|
||||
rl_model,
|
||||
)
|
||||
|
||||
|
||||
@@ -265,6 +266,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
|
||||
garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
|
||||
"Reinforcement Learning Optimization",
|
||||
rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
|
||||
lazy=True,
|
||||
),
|
||||
"InspectAI": lambda opts: dataset_from_iterator(
|
||||
"InspectAI",
|
||||
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import uuid as U
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
@@ -78,6 +80,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
auth_token: str = AUTH_TOKEN,
|
||||
history_size: int = 300,
|
||||
timeout: int = 5,
|
||||
run_id: str = "",
|
||||
):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
@@ -85,6 +88,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
self.api_url = api_url
|
||||
self.headers = {"Authorization": f"Bearer {auth_token}"}
|
||||
self.timeout = timeout
|
||||
self.run_id = run_id or U.uuid4().hex
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
return self.select_next_prompts(current_prompt, passed_guard)[0]
|
||||
@@ -94,6 +98,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
response = requests.post(
|
||||
f"{self.api_url}/rl-model/select-next-prompt",
|
||||
json={
|
||||
"run_id": U.uuid4().hex,
|
||||
"current_prompt": current_prompt,
|
||||
"passed_guard": passed_guard,
|
||||
},
|
||||
@@ -115,8 +120,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class QLearningPromptSelector(PromptSelectionInterface):
|
||||
@@ -197,3 +201,46 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
|
||||
# Update Q-value
|
||||
self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(
|
||||
self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
|
||||
):
|
||||
self.tools_inbox = tools_inbox
|
||||
self.opts = opts
|
||||
self.prompt_groups = prompt_groups
|
||||
self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts
|
||||
self.run_id = U.uuid4().hex
|
||||
self.batch_size = self.opts.get("batch_size", 500)
|
||||
self.rl_model = CloudRLPromptSelector(
|
||||
prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
|
||||
)
|
||||
|
||||
async def apply(self):
|
||||
current_prompt = "What is AI?"
|
||||
passed_guard = False
|
||||
for _ in range(max(self.max_prompts, 1)):
|
||||
# Fetch prompts from the API
|
||||
prompts = await asyncio.to_thread(
|
||||
lambda: self.rl_model.select_next_prompts(
|
||||
current_prompt, passed_guard=passed_guard
|
||||
)
|
||||
)
|
||||
|
||||
if not prompts:
|
||||
logger.error("No prompts retrieved from the API.")
|
||||
return
|
||||
|
||||
logger.info(f"Retrieved {len(prompts)} prompts.")
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
logger.info(f"Processing prompt {i+1}/{len(prompts)}: {prompt}")
|
||||
yield prompt
|
||||
current_prompt = prompt
|
||||
while not self.tools_inbox.empty():
|
||||
ref = await self.tools_inbox.get()
|
||||
print(ref, "ref")
|
||||
message, _, ready = ref["message"], ref["reply"], ref["ready"]
|
||||
yield message
|
||||
ready.set()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -8,6 +9,7 @@ import requests
|
||||
# Import the classes to be tested
|
||||
from agentic_security.probe_data.modules.rl_model import (
|
||||
CloudRLPromptSelector,
|
||||
Module,
|
||||
QLearningPromptSelector,
|
||||
RandomPromptSelector,
|
||||
)
|
||||
@@ -30,6 +32,19 @@ def mock_requests() -> Mock:
|
||||
yield mock_requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rl_selector() -> Mock:
|
||||
return CloudRLPromptSelector(
|
||||
dataset_prompts,
|
||||
api_url="https://edge.metaheuristic.co",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tools_inbox() -> asyncio.Queue:
|
||||
return asyncio.Queue()
|
||||
|
||||
|
||||
# Tests for RandomPromptSelector
|
||||
class TestRandomPromptSelector:
|
||||
def test_initialization(self, dataset_prompts):
|
||||
@@ -141,3 +156,60 @@ def test_cloud_rl_selector_invalid_url(dataset_prompts):
|
||||
def test_q_learning_selector_invalid_reward(dataset_prompts):
|
||||
selector = QLearningPromptSelector(dataset_prompts)
|
||||
selector.update_rewards("What is AI?", "How does RL work?", np.nan, True)
|
||||
|
||||
|
||||
# Tests for Module class
|
||||
class TestModule:
|
||||
@pytest.fixture
|
||||
def mock_uuid(self):
|
||||
with patch("uuid.uuid4") as mock:
|
||||
mock.return_value.hex = "test_run_id"
|
||||
yield mock
|
||||
|
||||
def test_initialization(self, dataset_prompts, tools_inbox, mock_uuid):
|
||||
module = Module(dataset_prompts, tools_inbox)
|
||||
assert module.prompt_groups == dataset_prompts
|
||||
assert module.tools_inbox == tools_inbox
|
||||
assert module.max_prompts == 2000
|
||||
assert module.batch_size == 500
|
||||
assert module.run_id == "test_run_id"
|
||||
assert isinstance(module.rl_model, CloudRLPromptSelector)
|
||||
|
||||
def test_initialization_with_options(self, dataset_prompts, tools_inbox, mock_uuid):
|
||||
opts = {
|
||||
"max_prompts": 100,
|
||||
"batch_size": 50,
|
||||
}
|
||||
module = Module(dataset_prompts, tools_inbox, opts)
|
||||
assert module.max_prompts == 100
|
||||
assert module.batch_size == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_basic_flow(
|
||||
self, dataset_prompts, tools_inbox, mock_rl_selector
|
||||
):
|
||||
module = Module(dataset_prompts, tools_inbox)
|
||||
|
||||
count = 0
|
||||
async for prompt in module.apply():
|
||||
assert prompt == "Test prompt"
|
||||
count += 1
|
||||
if count >= 3: # Test a few iterations
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox):
|
||||
# Add a test message to the tools inbox
|
||||
test_message = {
|
||||
"message": "Test message",
|
||||
"reply": None,
|
||||
"ready": asyncio.Event(),
|
||||
}
|
||||
await tools_inbox.put(test_message)
|
||||
|
||||
module = Module(dataset_prompts, tools_inbox)
|
||||
|
||||
async for output in module.apply():
|
||||
if output == "Test message":
|
||||
test_message["ready"].set()
|
||||
break
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import pytest
|
||||
import requests
|
||||
from agentic_security.probe_data import REGISTRY
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dataset", REGISTRY)
|
||||
def test_registry_accessibility(dataset):
|
||||
"""
|
||||
Validate that datasets from REGISTRY are accessible.
|
||||
- If it's a URL, check if the response status is 200.
|
||||
- If it's a cloud-hosted dataset, skip the test.
|
||||
"""
|
||||
dataset_name = dataset.get("dataset_name", "Unknown Dataset")
|
||||
dataset_url = dataset.get("url")
|
||||
|
||||
if not dataset_url:
|
||||
pytest.fail(f"Dataset {dataset_name} is missing a URL.")
|
||||
|
||||
if dataset_url.lower() == "cloud":
|
||||
pytest.skip(f"Skipping cloud dataset: {dataset_name}")
|
||||
|
||||
if isinstance(dataset_url, str) and dataset_url.startswith("http"):
|
||||
try:
|
||||
response = requests.head(
|
||||
dataset_url, timeout=5
|
||||
) # HEAD request for efficiency
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Dataset URL is inaccessible: {dataset_url}"
|
||||
except requests.exceptions.RequestException as e:
|
||||
pytest.fail(f"Request failed for {dataset_name} ({dataset_url}): {e}")
|
||||
|
||||
else:
|
||||
pytest.fail(f"Unexpected URL format for {dataset_name}: {dataset_url}")
|
||||
Reference in New Issue
Block a user