feat(add reinforcement_learning module):

This commit is contained in:
Alexander Myasoedov
2025-02-05 16:51:37 +02:00
parent 0bc4feef74
commit 6a8e7633d9
2 changed files with 366 additions and 0 deletions
@@ -0,0 +1,210 @@
from abc import ABC, abstractmethod
import numpy as np
import random
import requests
from collections import deque
from typing import List, Optional, Deque, Dict
class PromptSelectionInterface(ABC):
    """Contract shared by all prompt selection strategies.

    A strategy picks the prompt that follows the current one and may be
    told afterwards how well a given prompt-to-prompt transition worked.
    """

    @abstractmethod
    def select_next_prompt(
        self, current_prompt: str, model_response: Optional[str] = None
    ) -> str:
        """Return the prompt to use after ``current_prompt``.

        Args:
            current_prompt: The prompt that was used most recently.
            model_response: Optional model output for ``current_prompt``;
                implementations may ignore it.

        Returns:
            The next prompt to use.
        """

    @abstractmethod
    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        model_response: Optional[str] = None,
    ) -> None:
        """Record the reward for moving from one prompt to another.

        Args:
            previous_prompt: The prompt the transition started from.
            current_prompt: The prompt that was selected next.
            reward: Numeric outcome of that transition.
            model_response: Optional model output associated with the
                transition; implementations may ignore it.
        """
class RandomPromptSelector(PromptSelectionInterface):
    """Uniform random prompt selector that avoids recently used prompts.

    The last ``history_size`` prompts passed to ``select_next_prompt`` are
    remembered and excluded from the draw. When that leaves nothing to pick
    from, the history is reset and a new cycle starts.
    """

    def __init__(self, prompts: List[str], history_size: int = 3):
        """Create a selector over ``prompts``.

        Args:
            prompts: Non-empty list of candidate prompts.
            history_size: How many recent prompts to exclude from selection
                (used as the ``maxlen`` of the history deque).

        Raises:
            ValueError: If ``prompts`` is empty.
        """
        if not prompts:
            raise ValueError("Prompts list cannot be empty")
        self.prompts = prompts
        self.history: Deque[str] = deque(maxlen=history_size)

    def select_next_prompt(
        self, current_prompt: str, model_response: Optional[str] = None
    ) -> str:
        """Pick a random prompt that is not in the recent history.

        Args:
            current_prompt: The prompt just used; it is added to the history.
            model_response: Ignored by this strategy.

        Returns:
            A randomly chosen prompt. ``current_prompt`` itself is returned
            only when no other candidate exists (or when ``history_size`` is
            0, in which case nothing is ever excluded).
        """
        self.history.append(current_prompt)
        available = [p for p in self.prompts if p not in self.history]
        if not available:
            # Every prompt is in the recent history: start a fresh cycle.
            # Keep the current prompt recorded so the reset cannot hand the
            # same prompt straight back while an alternative exists.
            self.history.clear()
            self.history.append(current_prompt)
            available = [
                p for p in self.prompts if p != current_prompt
            ] or self.prompts
        return random.choice(available)

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        model_response: Optional[str] = None,
    ) -> None:
        """Do nothing; random selection does not learn from rewards."""
        pass  # No learning in random selection
class CloudRLPromptSelector(PromptSelectionInterface):
    """Prompt selector that delegates to a remote RL service over HTTP.

    Selection requests are POSTed to ``<api_url>/rl-select`` and reward
    updates to ``<api_url>/rl-update``. If a selection request fails or the
    reply contains no usable ``next_prompt``, a local random selection that
    avoids recently used prompts is used instead.
    """

    def __init__(
        self,
        prompts: List[str],
        api_url: str,
        auth_token: str,
        history_size: int = 3,
        timeout: int = 5,
    ):
        """Create a selector bound to a remote service.

        Args:
            prompts: Non-empty list of candidate prompts used by the fallback.
            api_url: Base URL of the service (a trailing slash is tolerated).
            auth_token: Token sent as an ``Authorization: Bearer`` header.
            history_size: How many recent prompts the fallback excludes.
            timeout: Timeout passed to every HTTP request.

        Raises:
            ValueError: If ``prompts`` is empty.
        """
        if not prompts:
            raise ValueError("Prompts list cannot be empty")
        self.prompts = prompts
        self.api_url = api_url
        self.headers = {"Authorization": f"Bearer {auth_token}"}
        self.history: Deque[str] = deque(maxlen=history_size)
        self.timeout = timeout

    def _endpoint(self, path: str) -> str:
        """Join ``path`` onto the base URL without producing a double slash."""
        return f"{self.api_url.rstrip('/')}/{path}"

    def select_next_prompt(
        self, current_prompt: str, model_response: Optional[str] = None
    ) -> str:
        """Ask the remote service for the next prompt, falling back locally.

        Args:
            current_prompt: The prompt just used; it is added to the history.
            model_response: Optional model output forwarded to the service.

        Returns:
            The ``next_prompt`` string from the service reply, or a locally
            chosen prompt when the request fails or the reply is unusable.
        """
        self.history.append(current_prompt)
        try:
            response = requests.post(
                self._endpoint("rl-select"),
                json={
                    "current_prompt": current_prompt,
                    "model_response": model_response,
                    "history": list(self.history),
                },
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            payload = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a body that is not valid JSON on requests
            # versions where the decode error is not a RequestException.
            print(f"Cloud request failed: {e}")
            return self._fallback_selection()

        next_prompt = (
            payload.get("next_prompt") if isinstance(payload, dict) else None
        )
        if isinstance(next_prompt, str):
            return next_prompt
        # Only reached for a missing/null/non-string value. The fallback is
        # evaluated lazily because it can reset the history as a side effect.
        return self._fallback_selection()

    def _fallback_selection(self) -> str:
        """Fallback to random selection when cloud service fails.

        Picks uniformly among prompts not in the recent history. When every
        prompt is in the history, the history is reset but its most recent
        entry is kept excluded so it is not immediately repeated while an
        alternative exists.
        """
        available = [p for p in self.prompts if p not in self.history]
        if not available:
            latest = self.history[-1] if self.history else None
            self.history.clear()
            if latest is not None:
                self.history.append(latest)
            available = [p for p in self.prompts if p != latest] or self.prompts
        return random.choice(available)

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        model_response: Optional[str] = None,
    ) -> None:
        """Send a reward for a prompt transition to the remote service.

        This is best-effort: transport errors and HTTP error statuses are
        reported on stdout and never raised to the caller.

        Args:
            previous_prompt: The prompt the transition started from.
            current_prompt: The prompt that was selected next.
            reward: Numeric outcome of that transition.
            model_response: Optional model output forwarded to the service.
        """
        try:
            response = requests.post(
                self._endpoint("rl-update"),
                json={
                    "previous_prompt": previous_prompt,
                    "current_prompt": current_prompt,
                    "reward": reward,
                    "model_response": model_response,
                },
                headers=self.headers,
                timeout=self.timeout,
            )
            # Surface 4xx/5xx replies instead of treating them as success.
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Reward update failed: {e}")
class QLearningPromptSelector(PromptSelectionInterface):
    """Q-Learning based prompt selector with exploration/exploitation tradeoff.

    States and actions are both prompts: ``q_table[state][action]`` is the
    learned value of moving from prompt ``state`` to prompt ``action``. There
    is deliberately no entry for staying on the same prompt.
    """

    def __init__(
        self,
        prompts: List[str],
        learning_rate: float = 0.1,
        discount_factor: float = 0.9,
        initial_exploration: float = 1.0,
        exploration_decay: float = 0.995,
        min_exploration: float = 0.01,
        history_size: int = 3,
    ):
        """Create a selector and initialise its Q-table.

        Args:
            prompts: Non-empty list of candidate prompts.
            learning_rate: Step size applied to the temporal-difference error.
            discount_factor: Weight of the best future Q-value in the target.
            initial_exploration: Starting probability of a random choice.
            exploration_decay: Factor applied to the exploration rate after
                every selection.
            min_exploration: Lower bound for the exploration rate.
            history_size: How many recent prompts to exclude from selection.

        Raises:
            ValueError: If ``prompts`` is empty.
        """
        if not prompts:
            raise ValueError("Prompts list cannot be empty")
        self.prompts = prompts
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = initial_exploration
        self.exploration_decay = exploration_decay
        self.min_exploration = min_exploration
        self.history: Deque[str] = deque(maxlen=history_size)
        # Initialize Q-table with small random values
        self.q_table: Dict[str, Dict[str, float]] = {
            state: {
                action: np.random.uniform(0, 0.1)
                for action in prompts
                if action != state
            }
            for state in prompts
        }

    def select_next_prompt(
        self, current_prompt: str, model_response: Optional[str] = None
    ) -> str:
        """Choose the next prompt by exploring or exploiting the Q-table.

        Args:
            current_prompt: The prompt just used; it is added to the history
                and acts as the current state.
            model_response: Ignored by this strategy.

        Returns:
            A random candidate with probability ``exploration_rate``,
            otherwise the candidate with the highest Q-value from
            ``current_prompt``. A random candidate is also returned when no
            candidate has a Q-value (e.g. ``current_prompt`` is not one of
            the known prompts, or it is the only prompt).
        """
        self.history.append(current_prompt)
        available = [a for a in self.prompts if a not in self.history]
        if not available:
            # Every prompt is in the recent history: start a fresh cycle but
            # keep the current prompt excluded while an alternative exists.
            self.history.clear()
            self.history.append(current_prompt)
            available = [
                a for a in self.prompts if a != current_prompt
            ] or self.prompts

        # Exploration-exploitation tradeoff
        if np.random.random() < self.exploration_rate:
            selected = random.choice(available)
        else:
            # The table has no state -> same-state entry and no row for an
            # unknown state, so only score candidates that have a Q-value.
            row = self.q_table.get(current_prompt, {})
            q_values = {a: row[a] for a in available if a in row}
            if q_values:
                selected = max(q_values, key=q_values.__getitem__)
            else:
                selected = random.choice(available)

        # Decay exploration rate
        self.exploration_rate = max(
            self.min_exploration, self.exploration_rate * self.exploration_decay
        )
        return selected

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        model_response: Optional[str] = None,
    ) -> None:
        """Apply one Q-learning update for ``previous_prompt -> current_prompt``.

        Transitions that have no Q-table entry (unknown prompts, or a prompt
        followed by itself) are ignored.

        Args:
            previous_prompt: The state the transition started from.
            current_prompt: The action taken, which is also the next state.
            reward: Reward observed for the transition.
            model_response: Ignored by this strategy.
        """
        if (
            previous_prompt not in self.q_table
            or current_prompt not in self.q_table[previous_prompt]
        ):
            return
        # Calculate temporal difference error; default covers a state with
        # no outgoing actions (single-prompt table).
        max_future_q = max(self.q_table[current_prompt].values(), default=0.0)
        td_target = reward + self.discount_factor * max_future_q
        td_error = td_target - self.q_table[previous_prompt][current_prompt]
        # Update Q-value
        self.q_table[previous_prompt][current_prompt] += (
            self.learning_rate * td_error
        )