diff --git a/config/config_loader.py b/config/config_loader.py index 65e8052..1cf04fe 100644 --- a/config/config_loader.py +++ b/config/config_loader.py @@ -188,27 +188,11 @@ class ConfigLoader: attack_names = test_case_cfg["attacks"] attack_params = test_case_cfg.get("attack_params", {}) or {} - # Merge attack configurations - merged_attack_params = {} - for attack_name in attack_names: - try: - attack_config = self.load_attack_config(attack_name) - base_params = attack_config.get("parameters", {}) - - # If there are override parameters in general config, merge them - if attack_name in attack_params: - merged_attack_params[attack_name] = self._deep_merge( - base_params, attack_params[attack_name] - ) - else: - merged_attack_params[attack_name] = base_params - except FileNotFoundError: - print( - f"Warning: Attack configuration file does not exist: {attack_name}" - ) - continue - - test_case_cfg["attack_params"] = merged_attack_params + test_case_cfg["attack_params"] = self._merge_component_configs( + component_names=attack_names, + override_params=attack_params, + config_loader_func=self.load_attack_config, + ) # Process response generation configuration if "response_generation" in full_config: @@ -245,31 +229,12 @@ class ConfigLoader: defense_names = response_cfg["defenses"] defense_overrides = response_cfg.get("defense_params", {}) or {} - # Merge defense configurations - merged_defense_params = {} - for defense_name in defense_names: - if defense_name == "None": - merged_defense_params[defense_name] = {} - continue - - try: - defense_config = self.load_defense_config(defense_name) - base_params = defense_config.get("parameters", {}) - - # If there are override parameters in general config, merge them - if defense_name in defense_overrides: - merged_defense_params[defense_name] = self._deep_merge( - base_params, defense_overrides[defense_name] - ) - else: - merged_defense_params[defense_name] = base_params - except FileNotFoundError: - print( - f"Warning: Defense configuration file does not exist: {defense_name}" - ) - continue - - response_cfg["defense_params"] = merged_defense_params + response_cfg["defense_params"] = self._merge_component_configs( + component_names=defense_names, + override_params=defense_overrides, + config_loader_func=self.load_defense_config, + skip_value="None", + ) # Process evaluation configuration if "evaluation" in full_config: @@ -361,6 +326,50 @@ class ConfigLoader: return result + def _merge_component_configs( + self, + component_names: List[str], + override_params: Dict[str, Any], + config_loader_func, # Callable: takes name, returns config dict + skip_value: str = None, + ) -> Dict[str, Any]: + """Merge component configurations from files with runtime overrides. 
+ + Args: + component_names: List of component names to process + override_params: Runtime override parameters from general_config + config_loader_func: Function to load base config for a component + skip_value: Special value that should be skipped with empty config + + Returns: + Merged configuration dictionary + """ + merged_params = {} + + for component_name in component_names: + if skip_value is not None and component_name == skip_value: + merged_params[component_name] = {} + continue + + try: + base_config = config_loader_func(component_name) + base_params = base_config.get("parameters", {}) + + if component_name in override_params: + merged_params[component_name] = self._deep_merge( + base_params, override_params[component_name] + ) + else: + merged_params[component_name] = base_params + + except FileNotFoundError: + print( + f"Warning: Component configuration file does not exist: {component_name}" + ) + continue + + return merged_params + def _load_yaml_file(self, file_path: Path) -> Dict[str, Any]: """Load YAML file""" with open(file_path, "r", encoding="utf-8") as f: diff --git a/core/base_classes.py b/core/base_classes.py index 28880b3..4fcce00 100644 --- a/core/base_classes.py +++ b/core/base_classes.py @@ -47,7 +47,11 @@ class BaseComponent(ABC): cfg_dict = config allowed_fields = {f.name for f in fields(self.CONFIG_CLASS)} filtered = {k: v for k, v in cfg_dict.items() if k in allowed_fields} - self.cfg = self.CONFIG_CLASS(**filtered) + try: + self.cfg = self.CONFIG_CLASS(**filtered) + except TypeError as e: + # Convert TypeError to ValueError for consistency + raise ValueError(f"Invalid configuration: {e}") from e else: # No configuration class or not a dataclass, use dictionary directly self.cfg = config @@ -77,6 +81,24 @@ class BaseComponent(ABC): pass # Subclasses can add more validation logic + def _determine_load_model(self) -> bool: + """Determine if local model needs to be loaded. + + Checks the load_model field in configuration (dict or dataclass). + Returns False if load_model is not explicitly set. 
+ """ + if hasattr(self, "cfg"): + config_obj = self.cfg + else: + config_obj = self.config + + if isinstance(config_obj, dict): + return config_obj.get("load_model", False) + elif hasattr(config_obj, "load_model"): + return getattr(config_obj, "load_model", False) + + return False + class BaseAttack(BaseComponent, ABC): """Attack method base class (enhanced version)""" @@ -109,25 +131,6 @@ class BaseAttack(BaseComponent, ABC): # Determine if local model needs to be loaded self.load_model = self._determine_load_model() - def _determine_load_model(self) -> bool: - """Determine if local model needs to be loaded - - Only decide based on the load_model field in configuration - """ - # Check if load_model field exists in configuration - if hasattr(self, "cfg"): - config_obj = self.cfg - else: - config_obj = self.config - - # Only check load_model configuration item - if isinstance(config_obj, dict): - return config_obj.get("load_model", False) - elif hasattr(config_obj, "load_model"): - return getattr(config_obj, "load_model", False) - - return False - @abstractmethod def generate_test_case( self, original_prompt: str, image_path: str, case_id: str, **kwargs @@ -259,25 +262,6 @@ class BaseDefense(BaseComponent, ABC): # Determine if local model needs to be loaded self.load_model = self._determine_load_model() - def _determine_load_model(self) -> bool: - """Determine if local model needs to be loaded - - Only based on the load_model field in configuration - """ - # Check if load_model field exists in configuration - if hasattr(self, "cfg"): - config_obj = self.cfg - else: - config_obj = self.config - - # Only check load_model configuration item - if isinstance(config_obj, dict): - return config_obj.get("load_model", False) - elif hasattr(config_obj, "load_model"): - return getattr(config_obj, "load_model", False) - - return False - @abstractmethod def apply_defense(self, test_case: TestCase, **kwargs) -> TestCase: """ diff --git a/core/unified_registry.py b/core/unified_registry.py index 2eb980e..b7b1340 100644 --- a/core/unified_registry.py +++ b/core/unified_registry.py @@ -109,124 +109,79 @@ class UnifiedRegistry: self.evaluator_registry[name] = evaluator_class self.logger.debug(f"Registered evaluator: {name}") - def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]: - """Get attack method class""" - if name in self.attack_registry: - return self.attack_registry[name] + def _get_component( + self, + name: str, + component_type: str, # "attacks", "models", "defenses", "evaluators" + ) -> Optional[Type]: + """Generic component getter with lazy loading. 
- # Get mapping information from config/plugins.yaml + Args: + name: Component name + component_type: One of "attacks", "models", "defenses", "evaluators" + + Returns: + Component class or None if not found + """ + type_info = { + "attacks": ("attack_registry", "attack method"), + "models": ("model_registry", "model"), + "defenses": ("defense_registry", "defense method"), + "evaluators": ("evaluator_registry", "evaluator"), + } + + if component_type not in type_info: + raise ValueError(f"Invalid component_type: {component_type}") + + registry_attr, type_name = type_info[component_type] + registry = getattr(self, registry_attr) + + # Check cache first + if name in registry: + return registry[name] + + # Try lazy loading from plugins.yaml try: mappings = self._get_lazy_mappings() - if name in mappings["attacks"]: - module_path, class_name = mappings["attacks"][name] + if name in mappings[component_type]: + module_path, class_name = mappings[component_type][name] module = importlib.import_module(module_path) cls = getattr(module, class_name) - # Register to cache - self.attack_registry[name] = cls + # Cache for future access + registry[name] = cls self.logger.debug( - f"Successfully imported attack method from mapping: {name}" + f"Successfully imported {type_name} from mapping: {name}" ) return cls except (ImportError, AttributeError) as e: self.logger.debug( - f"Unable to import attack method '{name}' from mapping: {e}" + f"Unable to import {type_name} '{name}' from mapping: {e}" ) return None except Exception as e: self.logger.error( - f"Unknown error occurred while importing attack method '{name}': {e}" + f"Unknown error occurred while importing {type_name} '{name}': {e}" ) return None - self.logger.warning(f"Attack method '{name}' is not defined in mapping") + self.logger.warning(f"{type_name.capitalize()} '{name}' is not defined in mapping") return None + def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]: + """Get attack method class""" + return self._get_component(name, "attacks") + def get_model(self, name: str) -> Optional[Type["BaseModel"]]: """Get model class""" - if name in self.model_registry: - return self.model_registry[name] - - # Get mapping information from config/plugins.yaml - try: - mappings = self._get_lazy_mappings() - if name in mappings["models"]: - module_path, class_name = mappings["models"][name] - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - self.model_registry[name] = cls - self.logger.debug(f"Successfully imported model from mapping: {name}") - return cls - except (ImportError, AttributeError) as e: - self.logger.debug(f"Unable to import model '{name}' from mapping: {e}") - return None - except Exception as e: - self.logger.error( - f"Unknown error occurred while importing model '{name}': {e}" - ) - return None - - self.logger.warning(f"Model '{name}' is not defined in mapping") - return None + return self._get_component(name, "models") def get_defense(self, name: str) -> Optional[Type["BaseDefense"]]: """Get defense method class""" - if name in self.defense_registry: - return self.defense_registry[name] - - # Get mapping information from config/plugins.yaml - try: - mappings = self._get_lazy_mappings() - if name in mappings["defenses"]: - module_path, class_name = mappings["defenses"][name] - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - self.defense_registry[name] = cls - self.logger.debug( - f"Successfully imported defense method from mapping: {name}" - ) - return cls - 
except (ImportError, AttributeError) as e: - self.logger.debug( - f"Unable to import defense method '{name}' from mapping: {e}" - ) - return None - except Exception as e: - self.logger.error( - f"Unknown error occurred while importing defense method '{name}': {e}" - ) - return None - - self.logger.warning(f"Defense method '{name}' is not defined in mapping") - return None + return self._get_component(name, "defenses") def get_evaluator(self, name: str) -> Optional[Type["BaseEvaluator"]]: """Get evaluator class""" - if name in self.evaluator_registry: - return self.evaluator_registry[name] - - # Get mapping information from config/plugins.yaml - try: - mappings = self._get_lazy_mappings() - if name in mappings["evaluators"]: - module_path, class_name = mappings["evaluators"][name] - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - self.evaluator_registry[name] = cls - self.logger.debug( - f"Successfully imported evaluator from mapping: {name}" - ) - return cls - except (ImportError, AttributeError) as e: - self.logger.debug(f"Unable to import evaluator '{name}' from mapping: {e}") - return None - except Exception as e: - self.logger.error( - f"Unknown error occurred while importing evaluator '{name}': {e}" - ) - return None - - self.logger.warning(f"Evaluator '{name}' is not defined in mapping") - return None + return self._get_component(name, "evaluators") def create_attack( self, name: str, config: Dict[str, Any] = None, output_image_dir: str = None diff --git a/models/anthropic_model.py b/models/anthropic_model.py index fd76bef..acf5814 100644 --- a/models/anthropic_model.py +++ b/models/anthropic_model.py @@ -7,6 +7,11 @@ class AnthropicModel(BaseModel): default_output = "I'm sorry, but I cannot assist with that request." 
+ # Anthropic-specific content policy keywords + PROVIDER_SPECIFIC_KEYWORDS = [ + "output blocked by content filtering policy", + ] + def __init__(self, model_name: str, api_key: str) -> None: super().__init__(model_name, api_key) @@ -53,15 +58,8 @@ class AnthropicModel(BaseModel): ) return response except Exception as e: - error_str = str(e).lower() - if "Output blocked by content filtering policy" in str(e): - return self.API_CONTENT_REJECTION_OUTPUT - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - return self.API_CONTENT_REJECTION_OUTPUT + if self._is_content_policy_rejection(e): + return self._handle_content_rejection() raise return self._retry_with_backoff(_api_call) @@ -100,23 +98,8 @@ class AnthropicModel(BaseModel): ) return stream except Exception as e: - error_str = str(e).lower() - if "Output blocked by content filtering policy" in str(e): - # Return a generator that yields the content rejection placeholder - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() + if self._is_content_policy_rejection(e): + return self._handle_content_rejection_stream() raise try: @@ -128,4 +111,4 @@ class AnthropicModel(BaseModel): elif hasattr(chunk, "completion") and chunk.completion: yield chunk.completion except Exception: - yield self.API_ERROR_OUTPUT + yield self._handle_api_error() diff --git a/models/base_model.py b/models/base_model.py index d2d95b0..a61f75f 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -18,6 +18,21 @@ class BaseModel(CoreBaseModel): API_MAX_RETRY = 3 API_TIMEOUT = 600 + # Content policy detection keywords (common across providers) + CONTENT_POLICY_KEYWORDS = [ + "content policy", + "safety", + "harmful", + "unsafe", + "violation", + "moderation", + "data_inspection_failed", + "inappropriate content", + ] + + # Provider-specific additional keywords (subclasses can override) + PROVIDER_SPECIFIC_KEYWORDS = [] + def __init__(self, model_name: str, api_key: str = None, base_url: str = None): # Call parent class __init__, pass empty configuration super().__init__(config={}) @@ -76,6 +91,46 @@ class BaseModel(CoreBaseModel): else: return "local" + def _is_content_policy_rejection(self, error: Exception) -> bool: + """Check if an exception represents a content policy rejection. 
+ + Args: + error: The exception to check + + Returns: + True if the error indicates content policy rejection + """ + error_str = str(error).lower() + + # Combine common and provider-specific keywords + all_keywords = self.CONTENT_POLICY_KEYWORDS + self.PROVIDER_SPECIFIC_KEYWORDS + + return any(keyword in error_str for keyword in all_keywords) + + def _handle_content_rejection(self) -> str: + """Return the standard content rejection output.""" + return self.API_CONTENT_REJECTION_OUTPUT + + def _handle_content_rejection_stream(self): + """Return a generator that yields the content rejection placeholder.""" + + def error_generator(): + yield self.API_CONTENT_REJECTION_OUTPUT + + return error_generator() + + def _handle_api_error(self) -> str: + """Return the standard API error output.""" + return self.API_ERROR_OUTPUT + + def _handle_api_error_stream(self): + """Return a generator that yields the API error placeholder.""" + + def error_generator(): + yield self.API_ERROR_OUTPUT + + return error_generator() + def _retry_with_backoff(self, func, *args, **kwargs): """Execute function with retry logic and exponential backoff using backoff library.""" diff --git a/models/doubao_model.py b/models/doubao_model.py index 06d1c86..a87b60b 100644 --- a/models/doubao_model.py +++ b/models/doubao_model.py @@ -33,20 +33,8 @@ class DoubaoModel(BaseModel): ) return response except Exception as e: - # Check for content policy violations in ByteDance models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - return self.API_CONTENT_REJECTION_OUTPUT + if self._is_content_policy_rejection(e): + return self._handle_content_rejection() raise return self._retry_with_backoff(_api_call) @@ -68,34 +56,8 @@ class DoubaoModel(BaseModel): ) return stream except Exception as e: - # Check for content policy violations in ByteDance models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - # Return a generator that yields the content rejection placeholder - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() + if self._is_content_policy_rejection(e): + return self._handle_content_rejection_stream() raise try: @@ -104,4 +66,4 @@ class DoubaoModel(BaseModel): if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content except Exception: - yield self.API_ERROR_OUTPUT + yield self._handle_api_error() diff --git a/models/google_model.py b/models/google_model.py index 41f5004..b7fa847 100644 --- a/models/google_model.py +++ b/models/google_model.py @@ -7,6 +7,11 @@ class GoogleModel(BaseModel): default_output = "I'm sorry, but I cannot assist with that request." 
+ # Google-specific content policy keywords + PROVIDER_SPECIFIC_KEYWORDS = [ + "blocked", + ] + def __init__(self, model_name: str, api_key: str, base_url: str = None) -> None: super().__init__(model_name=model_name, api_key=api_key, base_url=base_url) @@ -37,27 +42,8 @@ class GoogleModel(BaseModel): ) return response except Exception as e: - # Check for content policy violations in Gemini models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "blocked", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - return self.API_CONTENT_REJECTION_OUTPUT - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - return self.API_CONTENT_REJECTION_OUTPUT + if self._is_content_policy_rejection(e): + return self._handle_content_rejection() raise return self._retry_with_backoff(_api_call) @@ -86,4 +72,4 @@ class GoogleModel(BaseModel): if chunk.text: yield chunk.text except Exception: - yield self.API_ERROR_OUTPUT + yield self._handle_api_error() diff --git a/models/mistral_model.py b/models/mistral_model.py index 5c84a2c..3b84d60 100644 --- a/models/mistral_model.py +++ b/models/mistral_model.py @@ -47,26 +47,8 @@ class MistralModel(BaseModel): ) return chat_response except Exception as e: - # Check for content policy violations in Mistral models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - return self.API_CONTENT_REJECTION_OUTPUT - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - return self.API_CONTENT_REJECTION_OUTPUT + if self._is_content_policy_rejection(e): + return self._handle_content_rejection() raise return self._retry_with_backoff(_api_call) @@ -99,34 +81,8 @@ class MistralModel(BaseModel): ) return stream except Exception as e: - # Check for content policy violations in Mistral models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - # Return a generator that yields the content rejection placeholder - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() + if self._is_content_policy_rejection(e): + return self._handle_content_rejection_stream() raise try: @@ -136,4 +92,4 @@ class MistralModel(BaseModel): if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content except Exception: - yield self.API_ERROR_OUTPUT + yield self._handle_api_error() diff --git a/models/openai_model.py b/models/openai_model.py index f23a7c0..21de5c1 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -5,6 +5,14 @@ from .base_model import BaseModel class OpenAIModel(BaseModel): """OpenAI model implementation using OpenAI API.""" + # OpenAI-specific content policy keywords + 
PROVIDER_SPECIFIC_KEYWORDS = [ + "invalid", + "inappropriate", + "invalid_prompt", + "limited access", + ] + def __init__(self, model_name: str, api_key: str, base_url: Optional[str] = None): super().__init__(model_name=model_name, api_key=api_key, base_url=base_url) @@ -41,27 +49,11 @@ class OpenAIModel(BaseModel): return response except Exception as e: - # Check for content policy violations in GPT models - error_str = str(e).lower() - print("Error during API call:", error_str) - content_keywords = [ - "content policy", - "invalid", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate", - "invalid_prompt", - "limited access", - ] - if any(keyword in error_str for keyword in content_keywords): - print("✓ Content rejection triggered") - return self.API_CONTENT_REJECTION_OUTPUT - print("✗ No content keywords matched, raising exception") - raise e + print("Error during API call:", str(e).lower()) + if self._is_content_policy_rejection(e): + print("Content rejection triggered") + return self._handle_content_rejection() + raise return self._retry_with_backoff(_api_call) @@ -86,34 +78,8 @@ class OpenAIModel(BaseModel): ) return stream except Exception as e: - # Check for content policy violations in GPT models - error_str = str(e).lower() - content_keywords = [ - "content policy", - "safety", - "harmful", - "unsafe", - "violation", - "moderation", - "data_inspection_failed", - "inappropriate content", - ] - if any(keyword in error_str for keyword in content_keywords): - # Return a generator that yields the content rejection placeholder - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() - # Handle BadRequestError specifically - if ( - "badrequesterror" in error_str - and "data_inspection_failed" in error_str - ): - - def error_generator(): - yield self.API_CONTENT_REJECTION_OUTPUT - - return error_generator() + if self._is_content_policy_rejection(e): + return self._handle_content_rejection_stream() raise try: @@ -122,4 +88,4 @@ class OpenAIModel(BaseModel): if chunk.choices[0].delta.content is not None: yield chunk.choices[0].delta.content except Exception: - yield self.API_ERROR_OUTPUT + yield self._handle_api_error() diff --git a/pipeline/base_pipeline.py b/pipeline/base_pipeline.py index a3f9531..9633ae2 100644 --- a/pipeline/base_pipeline.py +++ b/pipeline/base_pipeline.py @@ -465,8 +465,19 @@ class BasePipeline(ABC): return str(filepath) def save_single_result(self, result: Dict, filename: str) -> str: - """Save single result to file (append mode)""" - filepath = self.output_dir / filename + """Save single result to file (append mode) + + Args: + result: The result to save + filename: File path (can be absolute or relative to output_dir) + + Returns: + str: Path to the saved file + """ + # Handle both absolute and relative paths + filepath = Path(filename) + if not filepath.is_absolute(): + filepath = self.output_dir / filename # Atomic write: write to temporary file first temp_file = filepath.with_suffix(".tmp") @@ -590,7 +601,7 @@ class BatchSaveManager: # Check if save is needed if len(self.buffer) >= self.batch_size: - self._flush_buffer() + self._flush_unlocked() def add_results(self, results: List[Dict]) -> None: """Batch add results""" @@ -604,14 +615,21 @@ class BatchSaveManager: self._save_batch(batch) def _flush_buffer(self) -> None: - """Save all results in buffer""" + """Save all results in buffer (thread-safe public method)""" + with self.lock: + self._flush_unlocked() + + 
def _flush_unlocked(self) -> None: + """Internal flush method, caller must hold lock. + + This is the actual implementation that flushes the buffer. + It must only be called while holding self.lock. + """ if not self.buffer: return - - with self.lock: - batch = self.buffer.copy() - self.buffer.clear() - self._save_batch(batch) + batch = self.buffer.copy() + self.buffer.clear() + self._save_batch(batch) def _save_batch(self, batch: List[Dict]) -> None: """Save a batch of results""" diff --git a/tests/test_core.py b/tests/test_core.py index 1e4038c..59e7bb9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -21,20 +21,17 @@ class TestBaseClasses: test_case_id="test", prompt="test", image_path="test", metadata={} ) - # Test case requiring local model loading - attack_with_model = TestAttack(config={"model_path": "/path/to/model"}) + # Test with explicit load_model=True + attack_with_model = TestAttack(config={"load_model": True}) assert attack_with_model.load_model is True - # Test case not requiring local model loading - attack_without_model = TestAttack(config={}) + # Test with explicit load_model=False + attack_without_model = TestAttack(config={"load_model": False}) assert attack_without_model.load_model is False - # Test other local model configuration items - attack_with_device = TestAttack(config={"device": "cuda:0"}) - assert attack_with_device.load_model is True - - attack_with_checkpoint = TestAttack(config={"checkpoint": "model.ckpt"}) - assert attack_with_checkpoint.load_model is True + # Test without load_model field (defaults to False) + attack_default = TestAttack(config={}) + assert attack_default.load_model is False def test_base_component_required_field_validation(self): """Test required field validation (error when missing, normal when provided)""" @@ -73,14 +70,18 @@ class TestBaseClasses: def apply_defense(self, test_case, **kwargs): return test_case - # Test case requiring local model loading - defense_with_model = TestDefense(config={"local_model": True}) + # Test with explicit load_model=True + defense_with_model = TestDefense(config={"load_model": True}) assert defense_with_model.load_model is True - # Test case not requiring local model loading - defense_without_model = TestDefense(config={}) + # Test with explicit load_model=False + defense_without_model = TestDefense(config={"load_model": False}) assert defense_without_model.load_model is False + # Test without load_model field (defaults to False) + defense_default = TestDefense(config={}) + assert defense_default.load_model is False + def test_base_model_model_type_detection(self): """Test model type detection""" from models.base_model import BaseModel diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d508833..2c26927 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -5,6 +5,8 @@ Test pipeline system import pytest import tempfile import json +import threading +import time from pathlib import Path from unittest.mock import Mock, patch, MagicMock @@ -133,22 +135,26 @@ class TestBasePipeline: pipeline_config = PipelineConfig(**test_config) pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation") + # Create a temporary file just to get a unique name with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - temp_file = Path(f.name) + temp_file_name = Path(f.name).name + + # The actual file will be created in pipeline's output_dir + output_file = pipeline.output_dir / temp_file_name try: first = {"test_case_id": "case_1", "data": 1} 
second = {"test_case_id": "case_1", "data": 2} - pipeline.save_single_result(first, temp_file.name) - pipeline.save_single_result(second, temp_file.name) + pipeline.save_single_result(first, temp_file_name) + pipeline.save_single_result(second, temp_file_name) - loaded = pipeline.load_results(temp_file) + loaded = pipeline.load_results(output_file) assert len(loaded) == 1 assert loaded[0]["data"] == 2 finally: - if temp_file.exists(): - temp_file.unlink() + if output_file.exists(): + output_file.unlink() def test_task_hash(self, test_config): """Test task hash generation""" @@ -232,6 +238,9 @@ class TestBatchSaveManager: # Verify buffer has been cleared assert len(manager.buffer) == 0 assert manager.total_saved == 2 + finally: + if temp_file.exists(): + temp_file.unlink() def test_batch_save_flush(self, test_config): """Test flush saves remaining buffer""" @@ -261,10 +270,6 @@ class TestBatchSaveManager: if temp_file.exists(): temp_file.unlink() - finally: - if temp_file.exists(): - temp_file.unlink() - def test_batch_save_context_manager(self, test_config): """Test batch save context manager""" from pipeline.base_pipeline import batch_save_context @@ -298,6 +303,110 @@ class TestBatchSaveManager: if temp_file.exists(): temp_file.unlink() + def test_batch_save_no_deadlock_when_exceeding_batch_size(self, test_config): + """Test that adding results equal to batch_size doesn't cause deadlock. + + This test verifies the fix for the deadlock bug where: + - add_result() acquires self.lock + - When buffer >= batch_size, _flush_buffer() is called + - _flush_buffer() tries to acquire self.lock again → DEADLOCK + + The fix uses _flush_unlocked() which assumes caller holds the lock. + """ + from pipeline.base_pipeline import BatchSaveManager + from core.data_formats import PipelineConfig + import tempfile + + pipeline_config = PipelineConfig(**test_config) + pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_file = Path(f.name) + + try: + manager = BatchSaveManager( + pipeline=pipeline, output_filename=temp_file, batch_size=3 + ) + + # Add exactly batch_size results - this would trigger the deadlock in the bug + results = [{"id": i, "data": f"result_{i}"} for i in range(3)] + + # Use a timeout to catch potential deadlocks + def add_results_with_timeout(): + for result in results: + manager.add_result(result) + + thread = threading.Thread(target=add_results_with_timeout) + thread.start() + + # Wait up to 5 seconds - if deadlock occurs, this will timeout + thread.join(timeout=5.0) + + assert not thread.is_alive(), "Thread is still alive - likely deadlock!" 
+ assert manager.total_saved == 3, f"Expected 3 saved, got {manager.total_saved}" + assert len(manager.buffer) == 0, "Buffer should be empty after batch save" + + finally: + if temp_file.exists(): + temp_file.unlink() + + def test_batch_save_concurrent_additions(self, test_config): + """Test thread-safe batch saving with concurrent additions""" + from pipeline.base_pipeline import BatchSaveManager + from core.data_formats import PipelineConfig + import tempfile + + pipeline_config = PipelineConfig(**test_config) + pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_file = Path(f.name) + + try: + manager = BatchSaveManager( + pipeline=pipeline, output_filename=temp_file, batch_size=10 + ) + + num_threads = 4 + results_per_thread = 25 + exceptions = [] + + def add_results(thread_id): + try: + for i in range(results_per_thread): + manager.add_result( + {"thread_id": thread_id, "index": i, "data": f"t{thread_id}_r{i}"} + ) + # Small random delay to increase contention + time.sleep(0.001) + except Exception as e: + exceptions.append((thread_id, e)) + + threads = [] + for i in range(num_threads): + t = threading.Thread(target=add_results, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads with timeout + for t in threads: + t.join(timeout=30.0) + + # Check no thread is stuck (deadlock) + assert not any(t.is_alive() for t in threads), "Some threads are still alive - likely deadlock!" + + # Check no exceptions occurred + assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + # Verify all results were saved + expected_total = num_threads * results_per_thread + assert manager.total_saved >= expected_total - manager.batch_size, \ + f"Expected at least {expected_total - manager.batch_size} saved, got {manager.total_saved}" + + finally: + if temp_file.exists(): + temp_file.unlink() + class TestParallelProcessing: """Test parallel processing"""
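
Two standalone sketches follow. They are not part of the patch and use hypothetical class names; they only illustrate, under stated assumptions, the two patterns this diff centralizes.

The first mirrors the deadlock fix described in test_batch_save_no_deadlock_when_exceeding_batch_size: threading.Lock is non-reentrant, so the code path that already holds the lock must call an unlocked flush variant instead of the public, locking one.

import threading
from typing import Any, Dict, List


class _BufferedSaver:
    """Hypothetical stand-in for BatchSaveManager; only the lock discipline matters."""

    def __init__(self, batch_size: int = 3) -> None:
        self.batch_size = batch_size
        self.buffer: List[Dict[str, Any]] = []
        self.total_saved = 0
        self.lock = threading.Lock()

    def add_result(self, result: Dict[str, Any]) -> None:
        with self.lock:
            self.buffer.append(result)
            if len(self.buffer) >= self.batch_size:
                # Already holding self.lock, so call the unlocked variant;
                # calling flush() here would deadlock on the non-reentrant lock.
                self._flush_unlocked()

    def flush(self) -> None:
        """Public, thread-safe flush."""
        with self.lock:
            self._flush_unlocked()

    def _flush_unlocked(self) -> None:
        """Internal flush; caller must hold self.lock."""
        if not self.buffer:
            return
        batch = self.buffer.copy()
        self.buffer.clear()
        self.total_saved += len(batch)  # a real implementation would persist `batch` here


saver = _BufferedSaver(batch_size=3)
for i in range(3):
    saver.add_result({"id": i})
assert saver.total_saved == 3 and not saver.buffer

The second sketches the keyword-matching scheme the patch moves into BaseModel: a shared keyword list combined with a provider-specific override, checked against the lowercased exception text (keyword lists abbreviated here).

from typing import List


class _BaseModelSketch:
    # Abbreviated version of CONTENT_POLICY_KEYWORDS from the patch.
    CONTENT_POLICY_KEYWORDS: List[str] = [
        "content policy",
        "safety",
        "data_inspection_failed",
    ]
    # Subclasses override this with provider-specific phrases.
    PROVIDER_SPECIFIC_KEYWORDS: List[str] = []

    def _is_content_policy_rejection(self, error: Exception) -> bool:
        error_str = str(error).lower()
        keywords = self.CONTENT_POLICY_KEYWORDS + self.PROVIDER_SPECIFIC_KEYWORDS
        return any(keyword in error_str for keyword in keywords)


class _AnthropicSketch(_BaseModelSketch):
    # Mirrors the Anthropic-specific entry added in the patch.
    PROVIDER_SPECIFIC_KEYWORDS = ["output blocked by content filtering policy"]


model = _AnthropicSketch()
assert model._is_content_policy_rejection(
    Exception("Output blocked by content filtering policy")
)
assert not model._is_content_policy_rejection(Exception("connection reset"))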