From e0eed6fd92804797026a26211de5adefa44168b3 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Fri, 7 Feb 2025 00:54:10 +0200 Subject: [PATCH] fix(rl_model.Module): --- agentic_security/probe_data/__init__.py | 15 ++++ agentic_security/probe_data/data.py | 6 ++ .../probe_data/modules/rl_model.py | 51 ++++++++++++- .../probe_data/modules/test_rl_model.py | 72 +++++++++++++++++++ agentic_security/test_registry.py | 34 --------- 5 files changed, 142 insertions(+), 36 deletions(-) delete mode 100644 agentic_security/test_registry.py diff --git a/agentic_security/probe_data/__init__.py b/agentic_security/probe_data/__init__.py index a1998d0..adf2282 100644 --- a/agentic_security/probe_data/__init__.py +++ b/agentic_security/probe_data/__init__.py @@ -408,6 +408,21 @@ REGISTRY = REGISTRY_V0 + [ }, "modality": "text", }, + { + "dataset_name": "Reinforcement Learning Optimization", + "num_prompts": 0, + "tokens": 0, + "approx_cost": 0.0, + "source": "Cloud hosted model", + "selected": False, + "url": "", + "dynamic": True, + "opts": { + "port": 8718, + "modules": ["encoding"], + }, + "modality": "text", + }, { "dataset_name": "InspectAI", "num_prompts": 0, diff --git a/agentic_security/probe_data/data.py b/agentic_security/probe_data/data.py index 27c9bd3..c0ef2d5 100644 --- a/agentic_security/probe_data/data.py +++ b/agentic_security/probe_data/data.py @@ -16,6 +16,7 @@ from agentic_security.probe_data.modules import ( fine_tuned, garak_tool, inspect_ai_tool, + rl_model, ) @@ -265,6 +266,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]): garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(), lazy=True, ), + "Reinforcement Learning Optimization": lambda opts: dataset_from_iterator( + "Reinforcement Learning Optimization", + rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(), + lazy=True, + ), "InspectAI": lambda opts: dataset_from_iterator( "InspectAI", inspect_ai_tool.Module(group, 
tools_inbox=tools_inbox).apply(), diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py index 9448381..5b424cc 100644 --- a/agentic_security/probe_data/modules/rl_model.py +++ b/agentic_security/probe_data/modules/rl_model.py @@ -1,5 +1,7 @@ +import asyncio import os import random +import uuid as U from abc import ABC, abstractmethod from collections import deque from typing import Deque @@ -78,6 +80,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): auth_token: str = AUTH_TOKEN, history_size: int = 300, timeout: int = 5, + run_id: str = "", ): if not prompts: raise ValueError("Prompts list cannot be empty") @@ -85,6 +88,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): self.api_url = api_url self.headers = {"Authorization": f"Bearer {auth_token}"} self.timeout = timeout + self.run_id = run_id or U.uuid4().hex def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]: return self.select_next_prompts(current_prompt, passed_guard)[0] @@ -94,6 +98,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): response = requests.post( f"{self.api_url}/rl-model/select-next-prompt", json={ + "run_id": self.run_id, "current_prompt": current_prompt, "passed_guard": passed_guard, }, @@ -115,8 +120,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): current_prompt: str, reward: float, passed_guard: bool, - ) -> None: - ... + ) -> None: ... 
class QLearningPromptSelector(PromptSelectionInterface): @@ -197,3 +201,46 @@ class QLearningPromptSelector(PromptSelectionInterface): # Update Q-value self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error + + +class Module: + def __init__( + self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {} + ): + self.tools_inbox = tools_inbox + self.opts = dict(opts or {}) + self.prompt_groups = prompt_groups + self.max_prompts = self.opts.get("max_prompts", 10) # Default max M prompts + self.run_id = U.uuid4().hex + self.batch_size = self.opts.get("batch_size", 500) + self.rl_model = CloudRLPromptSelector( + prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id + ) + + async def apply(self): + current_prompt = "What is AI?" + passed_guard = False + for _ in range(max(self.max_prompts, 1)): + # Fetch prompts from the API + prompts = await asyncio.to_thread( + lambda: self.rl_model.select_next_prompts( + current_prompt, passed_guard=passed_guard + ) + ) + + if not prompts: + logger.error("No prompts retrieved from the API.") + return + + logger.info(f"Retrieved {len(prompts)} prompts.") + + for i, prompt in enumerate(prompts): + logger.info(f"Processing prompt {i+1}/{len(prompts)}: {prompt}") + yield prompt + current_prompt = prompt + while not self.tools_inbox.empty(): + ref = await self.tools_inbox.get() + logger.debug(f"Received tools inbox message: {ref}") + message, _, ready = ref["message"], ref["reply"], ref["ready"] + yield message + ready.set() diff --git a/agentic_security/probe_data/modules/test_rl_model.py b/agentic_security/probe_data/modules/test_rl_model.py index f90162a..ab348bb 100644 --- a/agentic_security/probe_data/modules/test_rl_model.py +++ b/agentic_security/probe_data/modules/test_rl_model.py @@ -1,3 +1,4 @@ +import asyncio from collections import deque from unittest.mock import Mock, patch @@ -8,6 +9,7 @@ import requests # Import the classes to be tested from agentic_security.probe_data.modules.rl_model import ( CloudRLPromptSelector, + 
Module, QLearningPromptSelector, RandomPromptSelector, ) @@ -30,6 +32,19 @@ def mock_requests() -> Mock: yield mock_requests +@pytest.fixture +def mock_rl_selector(dataset_prompts) -> CloudRLPromptSelector: + return CloudRLPromptSelector( + dataset_prompts, + api_url="https://edge.metaheuristic.co", + ) + + +@pytest.fixture +def tools_inbox() -> asyncio.Queue: + return asyncio.Queue() + + # Tests for RandomPromptSelector class TestRandomPromptSelector: def test_initialization(self, dataset_prompts): @@ -141,3 +156,60 @@ def test_cloud_rl_selector_invalid_url(dataset_prompts): def test_q_learning_selector_invalid_reward(dataset_prompts): selector = QLearningPromptSelector(dataset_prompts) selector.update_rewards("What is AI?", "How does RL work?", np.nan, True) + + +# Tests for Module class +class TestModule: + @pytest.fixture + def mock_uuid(self): + with patch("uuid.uuid4") as mock: + mock.return_value.hex = "test_run_id" + yield mock + + def test_initialization(self, dataset_prompts, tools_inbox, mock_uuid): + module = Module(dataset_prompts, tools_inbox) + assert module.prompt_groups == dataset_prompts + assert module.tools_inbox == tools_inbox + assert module.max_prompts == 10 + assert module.batch_size == 500 + assert module.run_id == "test_run_id" + assert isinstance(module.rl_model, CloudRLPromptSelector) + + def test_initialization_with_options(self, dataset_prompts, tools_inbox, mock_uuid): + opts = { + "max_prompts": 100, + "batch_size": 50, + } + module = Module(dataset_prompts, tools_inbox, opts) + assert module.max_prompts == 100 + assert module.batch_size == 50 + + @pytest.mark.asyncio + async def test_apply_basic_flow( + self, dataset_prompts, tools_inbox, mock_rl_selector + ): + module = Module(dataset_prompts, tools_inbox) + + count = 0 + async for prompt in module.apply(): + assert prompt == "Test prompt" + count += 1 + if count >= 3: # Test a few iterations + break + + @pytest.mark.asyncio + async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox): + # Add 
a test message to the tools inbox + test_message = { + "message": "Test message", + "reply": None, + "ready": asyncio.Event(), + } + await tools_inbox.put(test_message) + + module = Module(dataset_prompts, tools_inbox) + + async for output in module.apply(): + if output == "Test message": + test_message["ready"].set() + break diff --git a/agentic_security/test_registry.py b/agentic_security/test_registry.py deleted file mode 100644 index 91739a1..0000000 --- a/agentic_security/test_registry.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -import requests -from agentic_security.probe_data import REGISTRY - - -@pytest.mark.parametrize("dataset", REGISTRY) -def test_registry_accessibility(dataset): - """ - Validate that datasets from REGISTRY are accessible. - - If it's a URL, check if the response status is 200. - - If it's a cloud-hosted dataset, skip the test. - """ - dataset_name = dataset.get("dataset_name", "Unknown Dataset") - dataset_url = dataset.get("url") - - if not dataset_url: - pytest.fail(f"Dataset {dataset_name} is missing a URL.") - - if dataset_url.lower() == "cloud": - pytest.skip(f"Skipping cloud dataset: {dataset_name}") - - if isinstance(dataset_url, str) and dataset_url.startswith("http"): - try: - response = requests.head( - dataset_url, timeout=5 - ) # HEAD request for efficiency - assert ( - response.status_code == 200 - ), f"Dataset URL is inaccessible: {dataset_url}" - except requests.exceptions.RequestException as e: - pytest.fail(f"Request failed for {dataset_name} ({dataset_url}): {e}") - - else: - pytest.fail(f"Unexpected URL format for {dataset_name}: {dataset_url}")