mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
fix(linter):
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import random
|
||||
import requests
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import List, Optional, Deque, Dict
|
||||
from typing import Deque
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
class PromptSelectionInterface(ABC):
|
||||
@@ -11,7 +12,7 @@ class PromptSelectionInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
self, current_prompt: str, model_response: str | None = None
|
||||
) -> str:
|
||||
"""Selects the next prompt based on current state and optional model response."""
|
||||
pass
|
||||
@@ -22,7 +23,7 @@ class PromptSelectionInterface(ABC):
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
model_response: str | None = None,
|
||||
) -> None:
|
||||
"""Updates internal rewards based on the outcome of the last selected prompt."""
|
||||
pass
|
||||
@@ -31,14 +32,14 @@ class PromptSelectionInterface(ABC):
|
||||
class RandomPromptSelector(PromptSelectionInterface):
|
||||
"""Random prompt selector with cycle prevention using history."""
|
||||
|
||||
def __init__(self, prompts: List[str], history_size: int = 3):
|
||||
def __init__(self, prompts: list[str], history_size: int = 3):
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty")
|
||||
self.prompts = prompts
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
self, current_prompt: str, model_response: str | None = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [p for p in self.prompts if p not in self.history]
|
||||
@@ -54,7 +55,7 @@ class RandomPromptSelector(PromptSelectionInterface):
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
model_response: str | None = None,
|
||||
) -> None:
|
||||
pass # No learning in random selection
|
||||
|
||||
@@ -64,7 +65,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompts: List[str],
|
||||
prompts: list[str],
|
||||
api_url: str,
|
||||
auth_token: str,
|
||||
history_size: int = 3,
|
||||
@@ -79,7 +80,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
self.timeout = timeout
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
self, current_prompt: str, model_response: str | None = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
|
||||
@@ -113,7 +114,7 @@ class CloudRLPromptSelector(PromptSelectionInterface):
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
model_response: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
@@ -136,7 +137,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompts: List[str],
|
||||
prompts: list[str],
|
||||
learning_rate: float = 0.1,
|
||||
discount_factor: float = 0.9,
|
||||
initial_exploration: float = 1.0,
|
||||
@@ -156,7 +157,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
self.history: Deque[str] = deque(maxlen=history_size)
|
||||
|
||||
# Initialize Q-table with small random values
|
||||
self.q_table: Dict[str, Dict[str, float]] = {
|
||||
self.q_table: dict[str, dict[str, float]] = {
|
||||
state: {
|
||||
action: np.random.uniform(0, 0.1)
|
||||
for action in prompts
|
||||
@@ -166,7 +167,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
}
|
||||
|
||||
def select_next_prompt(
|
||||
self, current_prompt: str, model_response: Optional[str] = None
|
||||
self, current_prompt: str, model_response: str | None = None
|
||||
) -> str:
|
||||
self.history.append(current_prompt)
|
||||
available = [a for a in self.prompts if a not in self.history]
|
||||
@@ -193,7 +194,7 @@ class QLearningPromptSelector(PromptSelectionInterface):
|
||||
previous_prompt: str,
|
||||
current_prompt: str,
|
||||
reward: float,
|
||||
model_response: Optional[str] = None,
|
||||
model_response: str | None = None,
|
||||
) -> None:
|
||||
if (
|
||||
previous_prompt not in self.q_table
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import pytest
|
||||
from collections import deque
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# Import the classes to be tested
|
||||
from agentic_security.probe_data.modules.rl_model import (
|
||||
PromptSelectionInterface,
|
||||
RandomPromptSelector,
|
||||
CloudRLPromptSelector,
|
||||
QLearningPromptSelector,
|
||||
RandomPromptSelector,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures for reusable test data
|
||||
@pytest.fixture
|
||||
def dataset_prompts() -> List[str]:
|
||||
def dataset_prompts() -> list[str]:
|
||||
return [
|
||||
"What is AI?",
|
||||
"How does RL work?",
|
||||
|
||||
Reference in New Issue
Block a user