fix(rl model): add select_next_prompts and point the cloud selector at the rl-model/select-next-prompt endpoint

This commit is contained in:
Alexander Myasoedov
2025-02-07 00:14:44 +02:00
parent 81ff6656e1
commit 01c27302de
2 changed files with 38 additions and 32 deletions
+27 -31
View File
@@ -1,3 +1,4 @@
import os
import random
from abc import ABC, abstractmethod
from collections import deque
@@ -5,6 +6,9 @@ from typing import Deque
import numpy as np
import requests
from loguru import logger
# Bearer token used as the default ``auth_token`` of the cloud selector, read
# from the AS_TOKEN environment variable.
# No credential is embedded as a fallback: a secret committed to source control
# is readable by everyone with repository access. When AS_TOKEN is unset the
# token is an empty string, so deployments must export AS_TOKEN or pass
# ``auth_token`` explicitly.
# NOTE(review): any token value that was ever committed here should be rotated.
AUTH_TOKEN: str = os.getenv("AS_TOKEN", "")
class PromptSelectionInterface(ABC):
@@ -15,6 +19,11 @@ class PromptSelectionInterface(ABC):
"""Selects the next prompt based on current state and guard result."""
pass
@abstractmethod
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
    """Select the next prompts from the current state and the guard result.

    Abstract hook: this base implementation has no logic and returns ``None``;
    concrete selectors supply the actual selection.

    Args:
        current_prompt: The current prompt.
        passed_guard: The guard result for the current prompt.

    Returns:
        The list of next prompts (in concrete implementations).
    """
@abstractmethod
def update_rewards(
self,
@@ -36,6 +45,9 @@ class RandomPromptSelector(PromptSelectionInterface):
self.prompts = prompts
self.history: Deque[str] = deque(maxlen=history_size)
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
    """Return the result of ``select_next_prompt`` wrapped in a one-element list.

    Args:
        current_prompt: Forwarded unchanged to ``select_next_prompt``.
        passed_guard: Forwarded unchanged to ``select_next_prompt``.

    Returns:
        A list holding exactly the single selected prompt.
    """
    chosen = self.select_next_prompt(current_prompt, passed_guard)
    return [chosen]
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
available = [p for p in self.prompts if p not in self.history]
@@ -63,8 +75,8 @@ class CloudRLPromptSelector(PromptSelectionInterface):
self,
prompts: list[str],
api_url: str,
auth_token: str,
history_size: int = 3,
auth_token: str = AUTH_TOKEN,
history_size: int = 300,
timeout: int = 5,
):
if not prompts:
@@ -72,36 +84,30 @@ class CloudRLPromptSelector(PromptSelectionInterface):
self.prompts = prompts
self.api_url = api_url
self.headers = {"Authorization": f"Bearer {auth_token}"}
self.history: Deque[str] = deque(maxlen=history_size)
self.timeout = timeout
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
    """Return a single next prompt by delegating to ``select_next_prompts``.

    Args:
        current_prompt: Forwarded unchanged to ``select_next_prompts``.
        passed_guard: Forwarded unchanged to ``select_next_prompts``.

    Returns:
        The first prompt from ``select_next_prompts``, or the result of
        ``_fallback_selection`` when that call yields no prompts.
    """
    candidates = self.select_next_prompts(current_prompt, passed_guard)
    # The service response may carry an empty or missing "next_prompts" list;
    # indexing it with [0] would raise IndexError, so fall back instead.
    if not candidates:
        return self._fallback_selection()
    return candidates[0]
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> str:
try:
response = requests.post(
f"{self.api_url}/rl-select",
f"{self.api_url}/rl-model/select-next-prompt",
json={
"current_prompt": current_prompt,
"passed_guard": passed_guard,
"history": list(self.history),
},
headers=self.headers,
timeout=self.timeout,
)
response.raise_for_status()
return response.json().get("next_prompt", self._fallback_selection())
return response.json().get("next_prompts", [])
except requests.exceptions.RequestException as e:
print(f"Cloud request failed: {e}")
return self._fallback_selection()
logger.error(f"Cloud request failed: {e}")
return [self._fallback_selection()]
def _fallback_selection(self) -> str:
"""Fallback to random selection when cloud service fails."""
available = [p for p in self.prompts if p not in self.history]
if not available:
available = self.prompts
self.history.clear()
return random.choice(available)
return random.choice(self.prompts)
def update_rewards(
self,
@@ -110,20 +116,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
reward: float,
passed_guard: bool,
) -> None:
try:
requests.post(
f"{self.api_url}/rl-update",
json={
"previous_prompt": previous_prompt,
"current_prompt": current_prompt,
"reward": reward,
"passed_guard": passed_guard,
},
headers=self.headers,
timeout=self.timeout,
)
except requests.exceptions.RequestException as e:
print(f"Reward update failed: {e}")
...
class QLearningPromptSelector(PromptSelectionInterface):
@@ -137,7 +130,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
initial_exploration: float = 1.0,
exploration_decay: float = 0.995,
min_exploration: float = 0.01,
history_size: int = 3,
history_size: int = 300,
):
if not prompts:
raise ValueError("Prompts list cannot be empty")
@@ -160,6 +153,9 @@ class QLearningPromptSelector(PromptSelectionInterface):
for state in prompts
}
def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
    """List-returning wrapper around ``select_next_prompt``.

    Args:
        current_prompt: Passed through to ``select_next_prompt``.
        passed_guard: Passed through to ``select_next_prompt``.

    Returns:
        A one-element list containing the prompt that ``select_next_prompt``
        returned.
    """
    single_choice = self.select_next_prompt(current_prompt, passed_guard)
    prompts_out = [single_choice]
    return prompts_out
def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
self.history.append(current_prompt)
available = [a for a in self.prompts if a not in self.history]
@@ -61,7 +61,7 @@ class TestCloudRLPromptSelector:
def test_select_next_prompt_success(self, dataset_prompts, mock_requests):
mock_requests.return_value.status_code = 200
mock_requests.return_value.json.return_value = {"next_prompt": "What is AI?"}
mock_requests.return_value.json.return_value = {"next_prompts": ["What is AI?"]}
selector = CloudRLPromptSelector(dataset_prompts, "http://example.com", "token")
next_prompt = selector.select_next_prompt(
@@ -76,6 +76,16 @@ class TestCloudRLPromptSelector:
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
assert next_prompt in dataset_prompts
def test_select_next_prompt_success_service(self, dataset_prompts):
    """Ask the selector at the hosted URL for a next prompt and expect a truthy result.

    NOTE(review): this test does not use the ``mock_requests`` fixture and passes
    no ``auth_token``, so it appears to call the real service with the
    selector's default token - confirm it is meant to run in the regular suite.
    """
    service_url = "https://edge.metaheuristic.co"
    selector = CloudRLPromptSelector(dataset_prompts, api_url=service_url)
    next_prompt = selector.select_next_prompt("How does RL work?", passed_guard=True)
    assert next_prompt
# Tests for QLearningPromptSelector
class TestQLearningPromptSelector: