mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(rl model):
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
@@ -5,6 +6,9 @@ from typing import Deque
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "gh0-5f4a8ed2-37c6-4bd7-a0cf-7070eae8115b")
|
||||
|
||||
|
||||
class PromptSelectionInterface(ABC):
|
||||
@@ -15,6 +19,11 @@ class PromptSelectionInterface(ABC):
|
||||
"""Selects the next prompt based on current state and guard result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
"""Selects the next prompts based on current state and guard result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_rewards(
|
||||
self,
|
||||
@@ -36,6 +45,9 @@ class RandomPromptSelector(PromptSelectionInterface):
|
||||
self.prompts = prompts
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
return [self.select_next_prompt(current_prompt, passed_guard)]
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [p for p in self.prompts if p not in self.history]
|
||||
@@ -63,8 +75,8 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
self,
|
||||
prompts: list[str],
|
||||
api_url: str,
|
||||
auth_token: str,
|
||||
history_size: int = 3,
|
||||
auth_token: str = AUTH_TOKEN,
|
||||
history_size: int = 300,
|
||||
timeout: int = 5,
|
||||
):
|
||||
if not prompts:
|
||||
@@ -72,36 +84,30 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
self.prompts = prompts
|
||||
self.api_url = api_url
|
||||
self.headers = {"Authorization": f"Bearer {auth_token}"}
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
self.timeout = timeout
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
self.history.append(current_prompt)
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
return self.select_next_prompts(current_prompt, passed_guard)[0]
|
||||
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_url}/rl-select",
|
||||
f"{self.api_url}/rl-model/select-next-prompt",
|
||||
json={
|
||||
"current_prompt": current_prompt,
|
||||
"passed_guard": passed_guard,
|
||||
"history": list(self.history),
|
||||
},
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get("next_prompt", self._fallback_selection())
|
||||
return response.json().get("next_prompts", [])
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Cloud request failed: {e}")
|
||||
return self._fallback_selection()
|
||||
logger.error(f"Cloud request failed: {e}")
|
||||
return [self._fallback_selection()]
|
||||
|
||||
def _fallback_selection(self) -> str:
|
||||
"""Fallback to random selection when cloud service fails."""
|
||||
available = [p for p in self.prompts if p not in self.history]
|
||||
if not available:
|
||||
available = self.prompts
|
||||
self.history.clear()
|
||||
return random.choice(available)
|
||||
return random.choice(self.prompts)
|
||||
|
||||
def update_rewards(
|
||||
self,
|
||||
@@ -110,20 +116,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
reward: float,
|
||||
passed_guard: bool,
|
||||
) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{self.api_url}/rl-update",
|
||||
json={
|
||||
"previous_prompt": previous_prompt,
|
||||
"current_prompt": current_prompt,
|
||||
"reward": reward,
|
||||
"passed_guard": passed_guard,
|
||||
},
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Reward update failed: {e}")
|
||||
...
|
||||
|
||||
|
||||
class QLearningPromptSelector(PromptSelectionInterface):
|
||||
@@ -137,7 +130,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
initial_exploration: float = 1.0,
|
||||
exploration_decay: float = 0.995,
|
||||
min_exploration: float = 0.01,
|
||||
history_size: int = 3,
|
||||
history_size: int = 300,
|
||||
):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
@@ -160,6 +153,9 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
for state in prompts
|
||||
}
|
||||
|
||||
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
|
||||
return [self.select_next_prompt(current_prompt, passed_guard)]
|
||||
|
||||
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [a for a in self.prompts if a not in self.history]
|
||||
|
||||
@@ -61,7 +61,7 @@ class TestCloudRLPromptSelector:
|
||||
|
||||
def test_select_next_prompt_success(self, dataset_prompts, mock_requests):
|
||||
mock_requests.return_value.status_code = 200
|
||||
mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"}
|
||||
mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]}
|
||||
|
||||
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
|
||||
next_prompt = selector.select_next_prompt(
|
||||
@@ -76,6 +76,16 @@ class TestCloudRLPromptSelector:
|
||||
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
|
||||
assert next_prompt in dataset_prompts
|
||||
|
||||
def test_select_next_prompt_success_service(self, dataset_prompts):
|
||||
selector = CloudRLPromptSelector(
|
||||
dataset_prompts,
|
||||
api_url="https://edge.metaheuristic.co",
|
||||
)
|
||||
next_prompt = selector.select_next_prompt(
|
||||
"How does RL work?", passed_guard=True
|
||||
)
|
||||
assert next_prompt
|
||||
|
||||
|
||||
# Tests for QLearningPromptSelector
|
||||
class TestQLearningPromptSelector:
|
||||
|
||||
Reference in New Issue
Block a user