mirror of
https://github.com/dongdongunique/EvoSynth.git
synced 2026-02-12 17:22:44 +00:00
136 lines
5.0 KiB
Python
136 lines
5.0 KiB
Python
# jailbreak_toolbox/models/implementations/multithreaded_model.py
|
|
import threading
|
|
import queue
|
|
import time
|
|
from typing import List, Dict, Any, Optional, Callable
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from .base_model import BaseModel
|
|
from ..core.registry import model_registry
|
|
|
|
@model_registry.register("multithreaded_model")
class MultiThreadedModel(BaseModel):
    """
    Base class for models that support multi-threaded API calls.

    This provides an efficient way to make multiple API calls in parallel,
    with rate limiting to avoid API throttling. Subclasses implement
    ``query``; ``query_batch`` fans the inputs out over a thread pool and
    sends every call through a shared rate limiter and a retry loop.
    """

    def __init__(
        self,
        max_workers: int = 5,
        requests_per_minute: int = 60,
        retry_attempts: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the multi-threaded model.

        Args:
            max_workers: Maximum number of worker threads
            requests_per_minute: Maximum number of requests per minute
                (must be positive)
            retry_attempts: Total number of attempts per query; values
                below 1 are treated as 1 when querying
            retry_delay: Base delay between retries in seconds; the wait
                before the k-th retry is ``retry_delay * k``
            **kwargs: Forwarded unchanged to the parent initializer

        Raises:
            ValueError: If ``requests_per_minute`` is not positive.
        """
        # Fail with a clear message instead of a ZeroDivisionError below.
        if requests_per_minute <= 0:
            raise ValueError(
                f"requests_per_minute must be positive, got {requests_per_minute!r}"
            )

        super().__init__(**kwargs)
        self.max_workers = max_workers
        self.requests_per_minute = requests_per_minute
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay

        # Rate limiting: minimum spacing (seconds) between request starts.
        self.request_interval = 60.0 / requests_per_minute
        # Monotonic timestamp of the last request; -inf means "none yet",
        # so the very first request is never delayed.
        self.last_request_time = float("-inf")
        self.request_lock = threading.Lock()

    def query_batch(
        self,
        inputs: List[Dict[str, Any]],
        callback: Optional[Callable[[int, str], None]] = None
    ) -> List[str]:
        """
        Send multiple queries in parallel using a thread pool.

        Args:
            inputs: List of input dictionaries. Only these keys are read:
                - 'text': Text input for the model (default '')
                - 'image': Optional image input (default None)
                - 'maintain_history': Passed through to ``query``
                  (default False)
            callback: Optional callback function to call when each result is ready
                Function signature: callback(index, response)
                It is invoked from the calling thread, once per input, in
                completion order.

        Returns:
            List of model responses in the same order as inputs. A query
            that still fails after all retries is reported as the string
            ``"Error: <message>"`` in its slot.

        Raises:
            Exception: Whatever ``callback`` raises is propagated to the
                caller; it is not recorded as a model error.
        """
        results = [None] * len(inputs)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all jobs, remembering which input each future belongs to.
            future_to_index = {
                executor.submit(
                    self._thread_safe_query,
                    input_dict.get('text', ''),
                    input_dict.get('image', None),
                    input_dict.get('maintain_history', False)
                ): i
                for i, input_dict in enumerate(inputs)
            }

            # Process results as they complete
            for future in as_completed(future_to_index):
                index = future_to_index[future]

                # Guard only the query outcome; keeping the callback out of
                # this try block means a callback failure cannot overwrite a
                # successful response or trigger a second callback call.
                try:
                    response = future.result()
                except Exception as e:
                    # Record error message as response
                    response = f"Error: {e}"

                results[index] = response
                if callback is not None:
                    callback(index, response)

        return results

    def _wait_for_rate_limit(self) -> None:
        """
        Block until the next request may start, then claim that slot.

        The sleep happens while holding ``request_lock`` on purpose: it
        serializes callers so consecutive request starts are spaced by at
        least ``request_interval`` seconds.
        """
        with self.request_lock:
            # Monotonic clock: unaffected by system clock adjustments.
            elapsed = time.monotonic() - self.last_request_time
            if elapsed < self.request_interval:
                time.sleep(self.request_interval - elapsed)
            self.last_request_time = time.monotonic()

    def _thread_safe_query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Thread-safe wrapper around the query method with rate limiting.

        Every attempt, including retries, waits for the rate limiter first.

        Args:
            text_input: The text input for the model
            image_input: Optional image input
            maintain_history: Whether to maintain conversation history

        Returns:
            The model's response

        Raises:
            Exception: The error raised by ``query`` on the final attempt.
        """
        # Always try at least once, even if retry_attempts is 0 or negative.
        attempts = max(1, self.retry_attempts)

        # All attempts except the last: swallow the error and back off.
        for retry_number in range(1, attempts):
            self._wait_for_rate_limit()
            try:
                return self.query(text_input, image_input, maintain_history)
            except Exception:
                # Linear backoff: retry_delay, 2 * retry_delay, ...
                time.sleep(self.retry_delay * retry_number)

        # Final attempt: any exception propagates to the caller unchanged.
        self._wait_for_rate_limit()
        return self.query(text_input, image_input, maintain_history)

    def query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Subclasses must implement this method.

        Args:
            text_input: The text input for the model
            image_input: Optional image input
            maintain_history: Whether to maintain conversation history

        Returns:
            The model's response

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the query method")