From 01c27302def0a7dee46fea7e3da1e3627327786a Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Fri, 7 Feb 2025 00:14:44 +0200 Subject: [PATCH] fix(rl model): --- .../probe_data/modules/rl_model.py | 58 +++++++++---------- .../probe_data/modules/test_rl_model.py | 12 +++- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/agentic_security/probe_data/modules/rl_model.py b/agentic_security/probe_data/modules/rl_model.py index f6e4dc2..9448381 100644 --- a/agentic_security/probe_data/modules/rl_model.py +++ b/agentic_security/probe_data/modules/rl_model.py @@ -1,3 +1,4 @@ +import os import random from abc import ABC, abstractmethod from collections import deque @@ -5,6 +6,9 @@ from typing import Deque import numpy as np import requests +from loguru import logger + +AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b") class PromptSelectionInterface(ABC): @@ -15,6 +19,11 @@ class PromptSelectionInterface(ABC): """Selects the next prompt based on current state and guard result.""" pass + @abstractmethod + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + """Selects the next prompts based on current state and guard result.""" + pass + @abstractmethod def update_rewards( self, @@ -36,6 +45,9 @@ class RandomPromptSelector(PromptSelectionInterface): self.prompts = prompts self.history: Deque[str] = deque(maxlen=history_size) + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + return [self.select_next_prompt(current_prompt, passed_guard)] + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: self.history.append(current_prompt) available = [p for p in self.prompts if p not in self.history] @@ -63,8 +75,8 @@ class CloudRLPromptSelector(PromptSelectionInterface): self, prompts: list[str], api_url: str, - auth_token: str, - history_size: int = 3, + auth_token: str = AUTH_TOKEN, + history_size: int = 300, timeout: int = 5, ): if not prompts: @@ -72,36 +84,30 @@ class CloudRLPromptSelector(PromptSelectionInterface): self.prompts = prompts self.api_url = api_url self.headers = {"Authorization": f"Bearer {auth_token}"} - self.history: Deque[str] = deque(maxlen=history_size) self.timeout = timeout - def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: - self.history.append(current_prompt) + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]: + return self.select_next_prompts(current_prompt, passed_guard)[0] + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> str: try: response = requests.post( - f"{self.api_url}/rl-select", + f"{self.api_url}/rl-model/select-next-prompt", json={ "current_prompt": current_prompt, "passed_guard": passed_guard, - "history": list(self.history), }, headers=self.headers, timeout=self.timeout, ) response.raise_for_status() - return response.json().get("next_prompt", self._fallback_selection()) + return response.json().get("next_prompts", []) except requests.exceptions.RequestException as e: - print(f"Cloud request failed: {e}") - return self._fallback_selection() + logger.error(f"Cloud request failed: {e}") + return [self._fallback_selection()] def _fallback_selection(self) -> str: - """Fallback to random selection when cloud service fails.""" - available = [p for p in self.prompts if p not in self.history] - if not available: - available = self.prompts - self.history.clear() - return random.choice(available) + return random.choice(self.prompts) def update_rewards( self, @@ -110,20 +116,7 @@ class CloudRLPromptSelector(PromptSelectionInterface): reward: float, passed_guard: bool, ) -> None: - try: - requests.post( - f"{self.api_url}/rl-update", - json={ - "previous_prompt": previous_prompt, - "current_prompt": current_prompt, - "reward": reward, - "passed_guard": passed_guard, - }, - headers=self.headers, - timeout=self.timeout, - ) - except requests.exceptions.RequestException as e: - print(f"Reward update failed: {e}") + ... class QLearningPromptSelector(PromptSelectionInterface): @@ -137,7 +130,7 @@ class QLearningPromptSelector(PromptSelectionInterface): initial_exploration: float = 1.0, exploration_decay: float = 0.995, min_exploration: float = 0.01, - history_size: int = 3, + history_size: int = 300, ): if not prompts: raise ValueError("Prompts list cannot be empty") @@ -160,6 +153,9 @@ class QLearningPromptSelector(PromptSelectionInterface): for state in prompts } + def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]: + return [self.select_next_prompt(current_prompt, passed_guard)] + def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str: self.history.append(current_prompt) available = [a for a in self.prompts if a not in self.history] diff --git a/agentic_security/probe_data/modules/test_rl_model.py b/agentic_security/probe_data/modules/test_rl_model.py index bf8bc7c..f90162a 100644 --- a/agentic_security/probe_data/modules/test_rl_model.py +++ b/agentic_security/probe_data/modules/test_rl_model.py @@ -61,7 +61,7 @@ class TestCloudRLPromptSelector: def test_select_next_prompt_success(self, dataset_prompts, mock_requests): mock_requests.return_value.status_code = 200 - mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"} + mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]} selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token") next_prompt = selector.select_next_prompt( @@ -76,6 +76,16 @@ class TestCloudRLPromptSelector: next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True) assert next_prompt in dataset_prompts + def test_select_next_prompt_success_service(self, dataset_prompts): + selector = CloudRLPromptSelector( + dataset_prompts, + api_url="https://edge.metaheuristic.co", + ) + next_prompt = selector.select_next_prompt( + "How does RL work?", passed_guard=True + ) + assert next_prompt + # Tests for QLearningPromptSelector class TestQLearningPromptSelector: