fix(linter):

This commit is contained in:
Alexander Myasoedov
2025-02-05 16:53:21 +02:00
parent 6a8e7633d9
commit b18427aa7e
2 changed files with 22 additions and 22 deletions
+17 -16
View File
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
import numpy as np
import random
import requests
from abc import ABC, abstractmethod
from collections import deque
from typing import List, Optional, Deque, Dict
from typing import Deque
import numpy as np
import requests
class PromptSelectionInterface(ABC):
@@ -11,7 +12,7 @@ class PromptSelectionInterface(ABC):
@abstractmethod
def select_next_prompt(
self, current_prompt: str, model_response: Optional[str] = None
self, current_prompt: str, model_response: str | None = None
) -> str:
"""Selects the next prompt based on current state and optional model response."""
pass
@@ -22,7 +23,7 @@ class PromptSelectionInterface(ABC):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: Optional[str] = None,
model_response: str | None = None,
) -> None:
"""Updates internal rewards based on the outcome of the last selected prompt."""
pass
@@ -31,14 +32,14 @@ class PromptSelectionInterface(ABC):
class RandomPromptSelector(PromptSelectionInterface):
"""Random prompt selector with cycle prevention using history."""
def __init__(self, prompts: List[str], history_size: int = 3):
def __init__(self, prompts: list[str], history_size: int = 3):
if not prompts:
raise ValueError("Prompts list cannot be empty")
self.prompts = prompts
self.history: Deque[str] = deque(maxlen=history_size)
def select_next_prompt(
self, current_prompt: str, model_response: Optional[str] = None
self, current_prompt: str, model_response: str | None = None
) -> str:
self.history.append(current_prompt)
available = [p for p in self.prompts if p not in self.history]
@@ -54,7 +55,7 @@ class RandomPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: Optional[str] = None,
model_response: str | None = None,
) -> None:
pass # No learning in random selection
@@ -64,7 +65,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
def __init__(
self,
prompts: List[str],
prompts: list[str],
api_url: str,
auth_token: str,
history_size: int = 3,
@@ -79,7 +80,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
self.timeout = timeout
def select_next_prompt(
self, current_prompt: str, model_response: Optional[str] = None
self, current_prompt: str, model_response: str | None = None
) -> str:
self.history.append(current_prompt)
@@ -113,7 +114,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: Optional[str] = None,
model_response: str | None = None,
) -> None:
try:
requests.post(
@@ -136,7 +137,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
def __init__(
self,
prompts: List[str],
prompts: list[str],
learning_rate: float = 0.1,
discount_factor: float = 0.9,
initial_exploration: float = 1.0,
@@ -156,7 +157,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
self.history: Deque[str] = deque(maxlen=history_size)
# Initialize Q-table with small random values
self.q_table: Dict[str, Dict[str, float]] = {
self.q_table: dict[str, dict[str, float]] = {
state: {
action: np.random.uniform(0, 0.1)
for action in prompts
@@ -166,7 +167,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
}
def select_next_prompt(
self, current_prompt: str, model_response: Optional[str] = None
self, current_prompt: str, model_response: str | None = None
) -> str:
self.history.append(current_prompt)
available = [a for a in self.prompts if a not in self.history]
@@ -193,7 +194,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
previous_prompt: str,
current_prompt: str,
reward: float,
model_response: Optional[str] = None,
model_response: str | None = None,
) -> None:
if (
previous_prompt not in self.q_table
@@ -1,21 +1,20 @@
import pytest
from collections import deque
from typing import List, Optional
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
import numpy as np
import pytest
# Import the classes to be tested
from agentic_security.probe_data.modules.rl_model import (
PromptSelectionInterface,
RandomPromptSelector,
CloudRLPromptSelector,
QLearningPromptSelector,
RandomPromptSelector,
)
# Fixtures for reusable test data
@pytest.fixture
def dataset_prompts() -> List[str]:
def dataset_prompts() -> list[str]:
return [
"What is AI?",
"How does RL work?",