# jailbreak_toolbox/models/implementations/multithreaded_model.py
import threading
import time
from typing import List, Dict, Any, Optional, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed

from .base_model import BaseModel
from ..core.registry import model_registry


@model_registry.register("multithreaded_model")
class MultiThreadedModel(BaseModel):
    """
    Base class for models that support multi-threaded API calls.

    This provides an efficient way to make multiple API calls in parallel,
    with rate limiting to avoid API throttling.
    """
    def __init__(
        self,
        max_workers: int = 5,
        requests_per_minute: int = 60,
        retry_attempts: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the multi-threaded model.

        Args:
            max_workers: Maximum number of worker threads
            requests_per_minute: Maximum number of requests per minute
            retry_attempts: Number of retry attempts on failure
            retry_delay: Base delay between retries in seconds
            **kwargs: Forwarded to BaseModel
        """
        super().__init__(**kwargs)
        self.max_workers = max_workers
        self.requests_per_minute = requests_per_minute
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay

        # Rate limiting: e.g. requests_per_minute=60 -> one request per second
        self.request_interval = 60.0 / requests_per_minute
        self.last_request_time = 0.0
        self.request_lock = threading.Lock()

    def query_batch(
        self,
        inputs: List[Dict[str, Any]],
        callback: Optional[Callable[[int, str], None]] = None
    ) -> List[str]:
        """
        Send multiple queries in parallel using a thread pool.

        Args:
            inputs: List of input dictionaries, each containing:
                - 'text': Text input for the model
                - 'image': Optional image input (path or PIL Image)
                - 'maintain_history': Whether to keep conversation history
                Any other keys are ignored.
            callback: Optional function called as each result is ready,
                with signature callback(index, response)

        Returns:
            List of model responses in the same order as inputs
        """
        results: List[Optional[str]] = [None] * len(inputs)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all jobs, remembering each future's input index
            future_to_index = {
                executor.submit(
                    self._thread_safe_query,
                    input_dict.get('text', ''),
                    input_dict.get('image', None),
                    input_dict.get('maintain_history', False)
                ): i
                for i, input_dict in enumerate(inputs)
            }
            # Process results as they complete
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    response = future.result()
                    results[index] = response
                    # Call callback if provided
                    if callback:
                        callback(index, response)
                except Exception as e:
                    # Record the error message as the response
                    results[index] = f"Error: {e}"
                    if callback:
                        callback(index, results[index])
        return results

    def _thread_safe_query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Thread-safe wrapper around the query method with rate limiting.

        Args:
            text_input: The text input for the model
            image_input: Optional image input
            maintain_history: Whether to maintain conversation history

        Returns:
            The model's response
        """
        # Apply rate limiting: serialize dispatch times so consecutive
        # requests are at least request_interval seconds apart
        with self.request_lock:
            current_time = time.time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.request_interval:
                time.sleep(self.request_interval - time_since_last)
            self.last_request_time = time.time()

        # Make the actual query with retries
        for attempt in range(self.retry_attempts):
            try:
                return self.query(text_input, image_input, maintain_history)
            except Exception:
                if attempt < self.retry_attempts - 1:
                    # Linearly increasing backoff between attempts
                    time.sleep(self.retry_delay * (attempt + 1))
                else:
                    raise  # Re-raise the exception on the last attempt

    def query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Send a single query to the model. Subclasses must implement this method.
        """
        raise NotImplementedError("Subclasses must implement the query method")
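

# Usage sketch (illustrative, not part of the upstream file): a minimal
# subclass with a stubbed `query`, assuming BaseModel.__init__ accepts
# arbitrary keyword arguments. `EchoModel` and the sample inputs below
# are hypothetical names for demonstration only.
#
#     class EchoModel(MultiThreadedModel):
#         def query(self, text_input="", image_input=None,
#                   maintain_history=False):
#             return f"echo: {text_input}"
#
#     model = EchoModel(max_workers=4, requests_per_minute=120)
#     responses = model.query_batch(
#         [{'text': 'hello'}, {'text': 'world'}],
#         callback=lambda i, resp: print(f"[{i}] {resp}"),
#     )
#     # responses == ['echo: hello', 'echo: world'], preserving input order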