# jailbreak_toolbox/models/implementations/multithreaded_model.py import threading import queue import time from typing import List, Dict, Any, Optional, Callable from concurrent.futures import ThreadPoolExecutor, as_completed from .base_model import BaseModel from ..core.registry import model_registry @model_registry.register("multithreaded_model") class MultiThreadedModel(BaseModel): """ Base class for models that support multi-threaded API calls. This provides an efficient way to make multiple API calls in parallel, with rate limiting to avoid API throttling. """ def __init__( self, max_workers: int = 5, requests_per_minute: int = 60, retry_attempts: int = 3, retry_delay: float = 1.0, **kwargs ): """ Initialize the multi-threaded model. Args: max_workers: Maximum number of worker threads requests_per_minute: Maximum number of requests per minute retry_attempts: Number of retry attempts on failure retry_delay: Delay between retries in seconds """ super().__init__(**kwargs) self.max_workers = max_workers self.requests_per_minute = requests_per_minute self.retry_attempts = retry_attempts self.retry_delay = retry_delay # Rate limiting self.request_interval = 60.0 / requests_per_minute self.last_request_time = 0 self.request_lock = threading.Lock() def query_batch( self, inputs: List[Dict[str, Any]], callback: Optional[Callable[[int, str], None]] = None ) -> List[str]: """ Send multiple queries in parallel using a thread pool. Args: inputs: List of input dictionaries, each containing: - 'text': Text input for the model - 'image': Optional image input (path or PIL Image) - Other model-specific parameters callback: Optional callback function to call when each result is ready Function signature: callback(index, response) Returns: List of model responses in the same order as inputs """ results = [None] * len(inputs) with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # Submit all jobs future_to_index = { executor.submit( self._thread_safe_query, input_dict.get('text', ''), input_dict.get('image', None), input_dict.get('maintain_history', False) ): i for i, input_dict in enumerate(inputs) } # Process results as they complete for future in as_completed(future_to_index): index = future_to_index[future] try: response = future.result() results[index] = response # Call callback if provided if callback: callback(index, response) except Exception as e: # Record error message as response results[index] = f"Error: {str(e)}" if callback: callback(index, results[index]) return results def _thread_safe_query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str: """ Thread-safe wrapper around the query method with rate limiting. Args: text_input: The text input for the model image_input: Optional image input maintain_history: Whether to maintain conversation history Returns: The model's response """ # Apply rate limiting with self.request_lock: current_time = time.time() time_since_last = current_time - self.last_request_time if time_since_last < self.request_interval: sleep_time = self.request_interval - time_since_last time.sleep(sleep_time) self.last_request_time = time.time() # Make the actual query with retries for attempt in range(self.retry_attempts): try: return self.query(text_input, image_input, maintain_history) except Exception as e: if attempt < self.retry_attempts - 1: time.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff else: raise e # Re-raise the exception on the last attempt def query(self, text_input: str = "", image_input: Any = None, maintain_history: bool = False) -> str: """ Subclasses must implement this method. """ raise NotImplementedError("Subclasses must implement the query method")