refactor: consolidate duplicate code and fix deadlock bug

- Extract duplicate attack/defense config merging into _merge_component_configs()
- Extract duplicate lazy loading logic into _get_component()
- Move content policy detection to BaseModel base class
- Fix BatchSaveManager deadlock by splitting flush logic
- Add TypeError to ValueError conversion for consistent config errors
- Move _determine_load_model() to BaseComponent (explicit field only)
Author: Liao, Jie
Date: 2026-01-23 12:55:34 +08:00
parent 84a4d1708e
commit 04a1cbe8d1
12 changed files with 384 additions and 400 deletions

View File

@@ -188,27 +188,11 @@ class ConfigLoader:
attack_names = test_case_cfg["attacks"]
attack_params = test_case_cfg.get("attack_params", {}) or {}
# Merge attack configurations
- merged_attack_params = {}
- for attack_name in attack_names:
-     try:
-         attack_config = self.load_attack_config(attack_name)
-         base_params = attack_config.get("parameters", {})
-         # If there are override parameters in general config, merge them
-         if attack_name in attack_params:
-             merged_attack_params[attack_name] = self._deep_merge(
-                 base_params, attack_params[attack_name]
-             )
-         else:
-             merged_attack_params[attack_name] = base_params
-     except FileNotFoundError:
-         print(
-             f"Warning: Attack configuration file does not exist: {attack_name}"
-         )
-         continue
- test_case_cfg["attack_params"] = merged_attack_params
+ test_case_cfg["attack_params"] = self._merge_component_configs(
+     component_names=attack_names,
+     override_params=attack_params,
+     config_loader_func=self.load_attack_config,
+ )
# Process response generation configuration
if "response_generation" in full_config:
@@ -245,31 +229,12 @@ class ConfigLoader:
defense_names = response_cfg["defenses"]
defense_overrides = response_cfg.get("defense_params", {}) or {}
# Merge defense configurations
- merged_defense_params = {}
- for defense_name in defense_names:
-     if defense_name == "None":
-         merged_defense_params[defense_name] = {}
-         continue
-     try:
-         defense_config = self.load_defense_config(defense_name)
-         base_params = defense_config.get("parameters", {})
-         # If there are override parameters in general config, merge them
-         if defense_name in defense_overrides:
-             merged_defense_params[defense_name] = self._deep_merge(
-                 base_params, defense_overrides[defense_name]
-             )
-         else:
-             merged_defense_params[defense_name] = base_params
-     except FileNotFoundError:
-         print(
-             f"Warning: Defense configuration file does not exist: {defense_name}"
-         )
-         continue
- response_cfg["defense_params"] = merged_defense_params
+ response_cfg["defense_params"] = self._merge_component_configs(
+     component_names=defense_names,
+     override_params=defense_overrides,
+     config_loader_func=self.load_defense_config,
+     skip_value="None",
+ )
# Process evaluation configuration
if "evaluation" in full_config:
@@ -361,6 +326,50 @@ class ConfigLoader:
return result
def _merge_component_configs(
self,
component_names: List[str],
override_params: Dict[str, Any],
config_loader_func, # Callable: takes name, returns config dict
skip_value: str = None,
) -> Dict[str, Any]:
"""Merge component configurations from files with runtime overrides.
Args:
component_names: List of component names to process
override_params: Runtime override parameters from general_config
config_loader_func: Function to load base config for a component
skip_value: Special value that should be skipped with empty config
Returns:
Merged configuration dictionary
"""
merged_params = {}
for component_name in component_names:
if skip_value is not None and component_name == skip_value:
merged_params[component_name] = {}
continue
try:
base_config = config_loader_func(component_name)
base_params = base_config.get("parameters", {})
if component_name in override_params:
merged_params[component_name] = self._deep_merge(
base_params, override_params[component_name]
)
else:
merged_params[component_name] = base_params
except FileNotFoundError:
print(
f"Warning: Component configuration file does not exist: {component_name}"
)
continue
return merged_params
def _load_yaml_file(self, file_path: Path) -> Dict[str, Any]:
"""Load YAML file"""
with open(file_path, "r", encoding="utf-8") as f:
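
For readers skimming the diff, here is a minimal standalone sketch of the merge behavior the new _merge_component_configs() helper centralizes: per-component file defaults deep-merged with runtime overrides, a skip value that maps to an empty config, and missing config files reduced to a warning. The deep_merge helper, fake_loader, and the "fgsm" entry below are illustrative stand-ins, not code from this repository.

from typing import Any, Dict

def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge override into base without mutating either dict."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

def merge_component_configs(names, overrides, load_config, skip_value=None):
    """Same shape as the new helper: file defaults plus runtime overrides per component."""
    merged = {}
    for name in names:
        if skip_value is not None and name == skip_value:
            merged[name] = {}  # e.g. the "None" defense gets an empty config
            continue
        try:
            base = load_config(name).get("parameters", {})
        except FileNotFoundError:
            print(f"Warning: Component configuration file does not exist: {name}")
            continue
        merged[name] = deep_merge(base, overrides[name]) if name in overrides else base
    return merged

def fake_loader(name):  # hypothetical stand-in for load_attack_config / load_defense_config
    return {"parameters": {"epsilon": 0.03, "steps": 10}}

print(merge_component_configs(["fgsm"], {"fgsm": {"steps": 40}}, fake_loader))
# -> {'fgsm': {'epsilon': 0.03, 'steps': 40}}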

View File

@@ -47,7 +47,11 @@ class BaseComponent(ABC):
cfg_dict = config
allowed_fields = {f.name for f in fields(self.CONFIG_CLASS)}
filtered = {k: v for k, v in cfg_dict.items() if k in allowed_fields}
+ try:
+     self.cfg = self.CONFIG_CLASS(**filtered)
+ except TypeError as e:
+     # Convert TypeError to ValueError for consistency
+     raise ValueError(f"Invalid configuration: {e}") from e
else:
# No configuration class or not a dataclass, use dictionary directly
self.cfg = config
@@ -77,6 +81,24 @@ class BaseComponent(ABC):
pass
# Subclasses can add more validation logic
def _determine_load_model(self) -> bool:
"""Determine if local model needs to be loaded.
Checks the load_model field in configuration (dict or dataclass).
Returns False if load_model is not explicitly set.
"""
if hasattr(self, "cfg"):
config_obj = self.cfg
else:
config_obj = self.config
if isinstance(config_obj, dict):
return config_obj.get("load_model", False)
elif hasattr(config_obj, "load_model"):
return getattr(config_obj, "load_model", False)
return False
class BaseAttack(BaseComponent, ABC):
"""Attack method base class (enhanced version)"""
@@ -109,25 +131,6 @@ class BaseAttack(BaseComponent, ABC):
# Determine if local model needs to be loaded
self.load_model = self._determine_load_model()
def _determine_load_model(self) -> bool:
"""Determine if local model needs to be loaded
Only decide based on the load_model field in configuration
"""
# Check if load_model field exists in configuration
if hasattr(self, "cfg"):
config_obj = self.cfg
else:
config_obj = self.config
# Only check load_model configuration item
if isinstance(config_obj, dict):
return config_obj.get("load_model", False)
elif hasattr(config_obj, "load_model"):
return getattr(config_obj, "load_model", False)
return False
@abstractmethod
def generate_test_case(
self, original_prompt: str, image_path: str, case_id: str, **kwargs
@@ -259,25 +262,6 @@ class BaseDefense(BaseComponent, ABC):
# Determine if local model needs to be loaded
self.load_model = self._determine_load_model()
def _determine_load_model(self) -> bool:
"""Determine if local model needs to be loaded
Only based on the load_model field in configuration
"""
# Check if load_model field exists in configuration
if hasattr(self, "cfg"):
config_obj = self.cfg
else:
config_obj = self.config
# Only check load_model configuration item
if isinstance(config_obj, dict):
return config_obj.get("load_model", False)
elif hasattr(config_obj, "load_model"):
return getattr(config_obj, "load_model", False)
return False
@abstractmethod
def apply_defense(self, test_case: TestCase, **kwargs) -> TestCase:
"""

View File

@@ -109,124 +109,79 @@ class UnifiedRegistry:
self.evaluator_registry[name] = evaluator_class
self.logger.debug(f"Registered evaluator: {name}")
- def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]:
-     """Get attack method class"""
-     if name in self.attack_registry:
-         return self.attack_registry[name]
-     # Get mapping information from config/plugins.yaml
+ def _get_component(
+     self,
+     name: str,
+     component_type: str,  # "attacks", "models", "defenses", "evaluators"
+ ) -> Optional[Type]:
+     """Generic component getter with lazy loading.
+     Args:
+         name: Component name
+         component_type: One of "attacks", "models", "defenses", "evaluators"
+     Returns:
+         Component class or None if not found
+     """
+     type_info = {
+         "attacks": ("attack_registry", "attack method"),
+         "models": ("model_registry", "model"),
+         "defenses": ("defense_registry", "defense method"),
+         "evaluators": ("evaluator_registry", "evaluator"),
+     }
+     if component_type not in type_info:
+         raise ValueError(f"Invalid component_type: {component_type}")
+     registry_attr, type_name = type_info[component_type]
+     registry = getattr(self, registry_attr)
+     # Check cache first
+     if name in registry:
+         return registry[name]
+     # Try lazy loading from plugins.yaml
      try:
          mappings = self._get_lazy_mappings()
-         if name in mappings["attacks"]:
-             module_path, class_name = mappings["attacks"][name]
+         if name in mappings[component_type]:
+             module_path, class_name = mappings[component_type][name]
              module = importlib.import_module(module_path)
              cls = getattr(module, class_name)
-             # Register to cache
-             self.attack_registry[name] = cls
+             # Cache for future access
+             registry[name] = cls
              self.logger.debug(
-                 f"Successfully imported attack method from mapping: {name}"
+                 f"Successfully imported {type_name} from mapping: {name}"
              )
              return cls
      except (ImportError, AttributeError) as e:
          self.logger.debug(
-             f"Unable to import attack method '{name}' from mapping: {e}"
+             f"Unable to import {type_name} '{name}' from mapping: {e}"
          )
          return None
      except Exception as e:
          self.logger.error(
-             f"Unknown error occurred while importing attack method '{name}': {e}"
+             f"Unknown error occurred while importing {type_name} '{name}': {e}"
          )
          return None
-     self.logger.warning(f"Attack method '{name}' is not defined in mapping")
+     self.logger.warning(f"{type_name.capitalize()} '{name}' is not defined in mapping")
      return None
+ def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]:
+     """Get attack method class"""
+     return self._get_component(name, "attacks")
def get_model(self, name: str) -> Optional[Type["BaseModel"]]:
"""Get model class"""
-     if name in self.model_registry:
-         return self.model_registry[name]
-     # Get mapping information from config/plugins.yaml
-     try:
-         mappings = self._get_lazy_mappings()
-         if name in mappings["models"]:
-             module_path, class_name = mappings["models"][name]
-             module = importlib.import_module(module_path)
-             cls = getattr(module, class_name)
-             self.model_registry[name] = cls
-             self.logger.debug(f"Successfully imported model from mapping: {name}")
-             return cls
-     except (ImportError, AttributeError) as e:
-         self.logger.debug(f"Unable to import model '{name}' from mapping: {e}")
-         return None
-     except Exception as e:
-         self.logger.error(
-             f"Unknown error occurred while importing model '{name}': {e}"
-         )
-         return None
-     self.logger.warning(f"Model '{name}' is not defined in mapping")
-     return None
+     return self._get_component(name, "models")
def get_defense(self, name: str) -> Optional[Type["BaseDefense"]]:
"""Get defense method class"""
-     if name in self.defense_registry:
-         return self.defense_registry[name]
-     # Get mapping information from config/plugins.yaml
-     try:
-         mappings = self._get_lazy_mappings()
-         if name in mappings["defenses"]:
-             module_path, class_name = mappings["defenses"][name]
-             module = importlib.import_module(module_path)
-             cls = getattr(module, class_name)
-             self.defense_registry[name] = cls
-             self.logger.debug(
-                 f"Successfully imported defense method from mapping: {name}"
-             )
-             return cls
-     except (ImportError, AttributeError) as e:
-         self.logger.debug(
-             f"Unable to import defense method '{name}' from mapping: {e}"
-         )
-         return None
-     except Exception as e:
-         self.logger.error(
-             f"Unknown error occurred while importing defense method '{name}': {e}"
-         )
-         return None
-     self.logger.warning(f"Defense method '{name}' is not defined in mapping")
-     return None
+     return self._get_component(name, "defenses")
def get_evaluator(self, name: str) -> Optional[Type["BaseEvaluator"]]:
"""Get evaluator class"""
-     if name in self.evaluator_registry:
-         return self.evaluator_registry[name]
-     # Get mapping information from config/plugins.yaml
-     try:
-         mappings = self._get_lazy_mappings()
-         if name in mappings["evaluators"]:
-             module_path, class_name = mappings["evaluators"][name]
-             module = importlib.import_module(module_path)
-             cls = getattr(module, class_name)
-             self.evaluator_registry[name] = cls
-             self.logger.debug(
-                 f"Successfully imported evaluator from mapping: {name}"
-             )
-             return cls
-     except (ImportError, AttributeError) as e:
-         self.logger.debug(f"Unable to import evaluator '{name}' from mapping: {e}")
-         return None
-     except Exception as e:
-         self.logger.error(
-             f"Unknown error occurred while importing evaluator '{name}': {e}"
-         )
-         return None
-     self.logger.warning(f"Evaluator '{name}' is not defined in mapping")
-     return None
+     return self._get_component(name, "evaluators")
def create_attack(
self, name: str, config: Dict[str, Any] = None, output_image_dir: str = None
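
The consolidated _get_component() boils down to a cache-then-importlib lookup keyed by component type. Below is a standalone sketch of that pattern; the mapping, component names, and stdlib classes are illustrative stand-ins for what config/plugins.yaml would provide, not the repository's actual plugin entries.

import importlib
from typing import Optional, Type

# Hypothetical plugin mapping standing in for what _get_lazy_mappings() reads from
# config/plugins.yaml; stdlib classes are used so the imports actually resolve.
LAZY_MAPPINGS = {
    "evaluators": {"json_decoder": ("json", "JSONDecoder")},
    "models": {"ordered_dict": ("collections", "OrderedDict")},
}
REGISTRIES = {"evaluators": {}, "models": {}}  # per-type caches

def get_component(name: str, component_type: str) -> Optional[Type]:
    registry = REGISTRIES[component_type]
    if name in registry:  # cache hit
        return registry[name]
    try:
        module_path, class_name = LAZY_MAPPINGS[component_type][name]
        cls = getattr(importlib.import_module(module_path), class_name)
        registry[name] = cls  # cache for future access
        return cls
    except KeyError:
        return None  # not defined in the mapping
    except (ImportError, AttributeError):
        return None  # mapping points at something that cannot be loaded

print(get_component("json_decoder", "evaluators"))  # <class 'json.decoder.JSONDecoder'>
print(get_component("missing", "models"))  # None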

View File

@@ -7,6 +7,11 @@ class AnthropicModel(BaseModel):
default_output = "I'm sorry, but I cannot assist with that request."
# Anthropic-specific content policy keywords
PROVIDER_SPECIFIC_KEYWORDS = [
"output blocked by content filtering policy",
]
def __init__(self, model_name: str, api_key: str) -> None:
super().__init__(model_name, api_key)
@@ -53,15 +58,8 @@ class AnthropicModel(BaseModel):
)
return response
except Exception as e:
-     error_str = str(e).lower()
-     if "Output blocked by content filtering policy" in str(e):
-         return self.API_CONTENT_REJECTION_OUTPUT
-     # Handle BadRequestError specifically
-     if (
-         "badrequesterror" in error_str
-         and "data_inspection_failed" in error_str
-     ):
-         return self.API_CONTENT_REJECTION_OUTPUT
+     if self._is_content_policy_rejection(e):
+         return self._handle_content_rejection()
raise
return self._retry_with_backoff(_api_call)
@@ -100,23 +98,8 @@ class AnthropicModel(BaseModel):
)
return stream
except Exception as e:
-     error_str = str(e).lower()
-     if "Output blocked by content filtering policy" in str(e):
-         # Return a generator that yields the content rejection placeholder
-         def error_generator():
-             yield self.API_CONTENT_REJECTION_OUTPUT
-         return error_generator()
-     # Handle BadRequestError specifically
-     if (
-         "badrequesterror" in error_str
-         and "data_inspection_failed" in error_str
-     ):
-         def error_generator():
-             yield self.API_CONTENT_REJECTION_OUTPUT
-         return error_generator()
+     if self._is_content_policy_rejection(e):
+         return self._handle_content_rejection_stream()
raise
try:
@@ -128,4 +111,4 @@ class AnthropicModel(BaseModel):
elif hasattr(chunk, "completion") and chunk.completion:
yield chunk.completion
except Exception:
- yield self.API_ERROR_OUTPUT
+ yield self._handle_api_error()

View File

@@ -18,6 +18,21 @@ class BaseModel(CoreBaseModel):
API_MAX_RETRY = 3
API_TIMEOUT = 600
# Content policy detection keywords (common across providers)
CONTENT_POLICY_KEYWORDS = [
"content policy",
"safety",
"harmful",
"unsafe",
"violation",
"moderation",
"data_inspection_failed",
"inappropriate content",
]
# Provider-specific additional keywords (subclasses can override)
PROVIDER_SPECIFIC_KEYWORDS = []
def __init__(self, model_name: str, api_key: str = None, base_url: str = None):
# Call parent class __init__, pass empty configuration
super().__init__(config={})
@@ -76,6 +91,46 @@ class BaseModel(CoreBaseModel):
else:
return "local"
def _is_content_policy_rejection(self, error: Exception) -> bool:
"""Check if an exception represents a content policy rejection.
Args:
error: The exception to check
Returns:
True if the error indicates content policy rejection
"""
error_str = str(error).lower()
# Combine common and provider-specific keywords
all_keywords = self.CONTENT_POLICY_KEYWORDS + self.PROVIDER_SPECIFIC_KEYWORDS
return any(keyword in error_str for keyword in all_keywords)
def _handle_content_rejection(self) -> str:
"""Return the standard content rejection output."""
return self.API_CONTENT_REJECTION_OUTPUT
def _handle_content_rejection_stream(self):
"""Return a generator that yields the content rejection placeholder."""
def error_generator():
yield self.API_CONTENT_REJECTION_OUTPUT
return error_generator()
def _handle_api_error(self) -> str:
"""Return the standard API error output."""
return self.API_ERROR_OUTPUT
def _handle_api_error_stream(self):
"""Return a generator that yields the API error placeholder."""
def error_generator():
yield self.API_ERROR_OUTPUT
return error_generator()
def _retry_with_backoff(self, func, *args, **kwargs):
"""Execute function with retry logic and exponential backoff using backoff library."""

View File

@@ -33,20 +33,8 @@ class DoubaoModel(BaseModel):
)
return response
except Exception as e:
- # Check for content policy violations in ByteDance models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     return self.API_CONTENT_REJECTION_OUTPUT
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection()
raise
return self._retry_with_backoff(_api_call)
@@ -68,34 +56,8 @@ class DoubaoModel(BaseModel):
)
return stream
except Exception as e:
- # Check for content policy violations in ByteDance models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     # Return a generator that yields the content rejection placeholder
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
- # Handle BadRequestError specifically
- if (
-     "badrequesterror" in error_str
-     and "data_inspection_failed" in error_str
- ):
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection_stream()
raise
try:
@@ -104,4 +66,4 @@ class DoubaoModel(BaseModel):
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception:
- yield self.API_ERROR_OUTPUT
+ yield self._handle_api_error()

View File

@@ -7,6 +7,11 @@ class GoogleModel(BaseModel):
default_output = "I'm sorry, but I cannot assist with that request."
# Google-specific content policy keywords
PROVIDER_SPECIFIC_KEYWORDS = [
"blocked",
]
def __init__(self, model_name: str, api_key: str, base_url: str = None) -> None:
super().__init__(model_name=model_name, api_key=api_key, base_url=base_url)
@@ -37,27 +42,8 @@ class GoogleModel(BaseModel):
)
return response
except Exception as e:
- # Check for content policy violations in Gemini models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "blocked",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     return self.API_CONTENT_REJECTION_OUTPUT
- # Handle BadRequestError specifically
- if (
-     "badrequesterror" in error_str
-     and "data_inspection_failed" in error_str
- ):
-     return self.API_CONTENT_REJECTION_OUTPUT
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection()
raise
return self._retry_with_backoff(_api_call)
@@ -86,4 +72,4 @@ class GoogleModel(BaseModel):
if chunk.text:
yield chunk.text
except Exception:
- yield self.API_ERROR_OUTPUT
+ yield self._handle_api_error()

View File

@@ -47,26 +47,8 @@ class MistralModel(BaseModel):
)
return chat_response
except Exception as e:
- # Check for content policy violations in Mistral models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     return self.API_CONTENT_REJECTION_OUTPUT
- # Handle BadRequestError specifically
- if (
-     "badrequesterror" in error_str
-     and "data_inspection_failed" in error_str
- ):
-     return self.API_CONTENT_REJECTION_OUTPUT
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection()
raise
return self._retry_with_backoff(_api_call)
@@ -99,34 +81,8 @@ class MistralModel(BaseModel):
)
return stream
except Exception as e:
- # Check for content policy violations in Mistral models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     # Return a generator that yields the content rejection placeholder
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
- # Handle BadRequestError specifically
- if (
-     "badrequesterror" in error_str
-     and "data_inspection_failed" in error_str
- ):
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection_stream()
raise
try:
@@ -136,4 +92,4 @@ class MistralModel(BaseModel):
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception:
- yield self.API_ERROR_OUTPUT
+ yield self._handle_api_error()

View File

@@ -5,6 +5,14 @@ from .base_model import BaseModel
class OpenAIModel(BaseModel):
"""OpenAI model implementation using OpenAI API."""
# OpenAI-specific content policy keywords
PROVIDER_SPECIFIC_KEYWORDS = [
"invalid",
"inappropriate",
"invalid_prompt",
"limited access",
]
def __init__(self, model_name: str, api_key: str, base_url: Optional[str] = None):
super().__init__(model_name=model_name, api_key=api_key, base_url=base_url)
@@ -41,27 +49,11 @@ class OpenAIModel(BaseModel):
return response
except Exception as e:
- # Check for content policy violations in GPT models
- error_str = str(e).lower()
- print("Error during API call:", error_str)
- content_keywords = [
-     "content policy",
-     "invalid",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate",
-     "invalid_prompt",
-     "limited access",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     print("✓ Content rejection triggered")
-     return self.API_CONTENT_REJECTION_OUTPUT
- print("✗ No content keywords matched, raising exception")
- raise e
+ print("Error during API call:", str(e).lower())
+ if self._is_content_policy_rejection(e):
+     print("Content rejection triggered")
+     return self._handle_content_rejection()
+ raise
return self._retry_with_backoff(_api_call)
@@ -86,34 +78,8 @@ class OpenAIModel(BaseModel):
)
return stream
except Exception as e:
- # Check for content policy violations in GPT models
- error_str = str(e).lower()
- content_keywords = [
-     "content policy",
-     "safety",
-     "harmful",
-     "unsafe",
-     "violation",
-     "moderation",
-     "data_inspection_failed",
-     "inappropriate content",
- ]
- if any(keyword in error_str for keyword in content_keywords):
-     # Return a generator that yields the content rejection placeholder
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
- # Handle BadRequestError specifically
- if (
-     "badrequesterror" in error_str
-     and "data_inspection_failed" in error_str
- ):
-     def error_generator():
-         yield self.API_CONTENT_REJECTION_OUTPUT
-     return error_generator()
+ if self._is_content_policy_rejection(e):
+     return self._handle_content_rejection_stream()
raise
try:
@@ -122,4 +88,4 @@ class OpenAIModel(BaseModel):
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception:
- yield self.API_ERROR_OUTPUT
+ yield self._handle_api_error()

View File

@@ -465,7 +465,18 @@ class BasePipeline(ABC):
return str(filepath)
def save_single_result(self, result: Dict, filename: str) -> str:
"""Save single result to file (append mode)"""
"""Save single result to file (append mode)
Args:
result: The result to save
filename: File path (can be absolute or relative to output_dir)
Returns:
str: Path to the saved file
"""
# Handle both absolute and relative paths
filepath = Path(filename)
if not filepath.is_absolute():
filepath = self.output_dir / filename
# Atomic write: write to temporary file first
@@ -590,7 +601,7 @@ class BatchSaveManager:
# Check if save is needed
if len(self.buffer) >= self.batch_size:
- self._flush_buffer()
+ self._flush_unlocked()
def add_results(self, results: List[Dict]) -> None:
"""Batch add results"""
@@ -604,11 +615,18 @@ class BatchSaveManager:
self._save_batch(batch)
def _flush_buffer(self) -> None:
"""Save all results in buffer"""
"""Save all results in buffer (thread-safe public method)"""
with self.lock:
self._flush_unlocked()
def _flush_unlocked(self) -> None:
"""Internal flush method, caller must hold lock.
This is the actual implementation that flushes the buffer.
It must only be called while holding self.lock.
"""
if not self.buffer:
return
with self.lock:
batch = self.buffer.copy()
self.buffer.clear()
self._save_batch(batch)
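
The deadlock fix hinges on threading.Lock being non-reentrant: add_result() already holds the lock when the buffer fills, so calling a flush method that re-acquires the same lock blocks forever. Below is a minimal standalone sketch of the corrected split; BufferedSaver is a simplified stand-in for BatchSaveManager, not the repository's class.

import threading

class BufferedSaver:  # simplified stand-in for BatchSaveManager, for illustration only
    def __init__(self, batch_size: int = 3):
        self.batch_size = batch_size
        self.buffer = []
        self.total_saved = 0
        self.lock = threading.Lock()  # non-reentrant: re-acquiring in the same thread blocks forever

    def add_result(self, result) -> None:
        with self.lock:
            self.buffer.append(result)
            if len(self.buffer) >= self.batch_size:
                # Calling flush() here would try to take self.lock again -> deadlock.
                self._flush_unlocked()

    def flush(self) -> None:
        """Thread-safe public flush."""
        with self.lock:
            self._flush_unlocked()

    def _flush_unlocked(self) -> None:
        """Actual flush; the caller must already hold self.lock."""
        if not self.buffer:
            return
        batch, self.buffer = self.buffer, []
        self.total_saved += len(batch)  # a real implementation would write the batch out here

saver = BufferedSaver(batch_size=2)
for i in range(5):
    saver.add_result({"id": i})
saver.flush()
print(saver.total_saved)  # 5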

View File

@@ -21,20 +21,17 @@ class TestBaseClasses:
test_case_id="test", prompt="test", image_path="test", metadata={}
)
- # Test case requiring local model loading
- attack_with_model = TestAttack(config={"model_path": "/path/to/model"})
+ # Test with explicit load_model=True
+ attack_with_model = TestAttack(config={"load_model": True})
  assert attack_with_model.load_model is True
- # Test case not requiring local model loading
- attack_without_model = TestAttack(config={})
+ # Test with explicit load_model=False
+ attack_without_model = TestAttack(config={"load_model": False})
  assert attack_without_model.load_model is False
- # Test other local model configuration items
- attack_with_device = TestAttack(config={"device": "cuda:0"})
- assert attack_with_device.load_model is True
- attack_with_checkpoint = TestAttack(config={"checkpoint": "model.ckpt"})
- assert attack_with_checkpoint.load_model is True
+ # Test without load_model field (defaults to False)
+ attack_default = TestAttack(config={})
+ assert attack_default.load_model is False
def test_base_component_required_field_validation(self):
"""Test required field validation (error when missing, normal when provided)"""
@@ -73,14 +70,18 @@ class TestBaseClasses:
def apply_defense(self, test_case, **kwargs):
return test_case
- # Test case requiring local model loading
- defense_with_model = TestDefense(config={"local_model": True})
+ # Test with explicit load_model=True
+ defense_with_model = TestDefense(config={"load_model": True})
  assert defense_with_model.load_model is True
- # Test case not requiring local model loading
- defense_without_model = TestDefense(config={})
+ # Test with explicit load_model=False
+ defense_without_model = TestDefense(config={"load_model": False})
  assert defense_without_model.load_model is False
+ # Test without load_model field (defaults to False)
+ defense_default = TestDefense(config={})
+ assert defense_default.load_model is False
def test_base_model_model_type_detection(self):
"""Test model type detection"""
from models.base_model import BaseModel

View File

@@ -5,6 +5,8 @@ Test pipeline system
import pytest
import tempfile
import json
import threading
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
@@ -133,22 +135,26 @@ class TestBasePipeline:
pipeline_config = PipelineConfig(**test_config)
pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
# Create a temporary file just to get a unique name
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
- temp_file = Path(f.name)
+ temp_file_name = Path(f.name).name
+ # The actual file will be created in pipeline's output_dir
+ output_file = pipeline.output_dir / temp_file_name
  try:
      first = {"test_case_id": "case_1", "data": 1}
      second = {"test_case_id": "case_1", "data": 2}
-     pipeline.save_single_result(first, temp_file.name)
-     pipeline.save_single_result(second, temp_file.name)
+     pipeline.save_single_result(first, temp_file_name)
+     pipeline.save_single_result(second, temp_file_name)
-     loaded = pipeline.load_results(temp_file)
+     loaded = pipeline.load_results(output_file)
      assert len(loaded) == 1
      assert loaded[0]["data"] == 2
  finally:
-     if temp_file.exists():
-         temp_file.unlink()
+     if output_file.exists():
+         output_file.unlink()
def test_task_hash(self, test_config):
"""Test task hash generation"""
@@ -232,6 +238,9 @@ class TestBatchSaveManager:
# Verify buffer has been cleared
assert len(manager.buffer) == 0
assert manager.total_saved == 2
finally:
if temp_file.exists():
temp_file.unlink()
def test_batch_save_flush(self, test_config):
"""Test flush saves remaining buffer"""
@@ -261,10 +270,6 @@ class TestBatchSaveManager:
if temp_file.exists():
temp_file.unlink()
- finally:
-     if temp_file.exists():
-         temp_file.unlink()
def test_batch_save_context_manager(self, test_config):
"""Test batch save context manager"""
from pipeline.base_pipeline import batch_save_context
@@ -298,6 +303,110 @@ class TestBatchSaveManager:
if temp_file.exists():
temp_file.unlink()
def test_batch_save_no_deadlock_when_exceeding_batch_size(self, test_config):
"""Test that adding results equal to batch_size doesn't cause deadlock.
This test verifies the fix for the deadlock bug where:
- add_result() acquires self.lock
- When buffer >= batch_size, _flush_buffer() is called
- _flush_buffer() tries to acquire self.lock again → DEADLOCK
The fix uses _flush_unlocked() which assumes caller holds the lock.
"""
from pipeline.base_pipeline import BatchSaveManager
from core.data_formats import PipelineConfig
import tempfile
pipeline_config = PipelineConfig(**test_config)
pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_file = Path(f.name)
try:
manager = BatchSaveManager(
pipeline=pipeline, output_filename=temp_file, batch_size=3
)
# Add exactly batch_size results - this would trigger the deadlock in the bug
results = [{"id": i, "data": f"result_{i}"} for i in range(3)]
# Use a timeout to catch potential deadlocks
def add_results_with_timeout():
for result in results:
manager.add_result(result)
thread = threading.Thread(target=add_results_with_timeout)
thread.start()
# Wait up to 5 seconds - if deadlock occurs, this will timeout
thread.join(timeout=5.0)
assert not thread.is_alive(), "Thread is still alive - likely deadlock!"
assert manager.total_saved == 3, f"Expected 3 saved, got {manager.total_saved}"
assert len(manager.buffer) == 0, "Buffer should be empty after batch save"
finally:
if temp_file.exists():
temp_file.unlink()
def test_batch_save_concurrent_additions(self, test_config):
"""Test thread-safe batch saving with concurrent additions"""
from pipeline.base_pipeline import BatchSaveManager
from core.data_formats import PipelineConfig
import tempfile
pipeline_config = PipelineConfig(**test_config)
pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_file = Path(f.name)
try:
manager = BatchSaveManager(
pipeline=pipeline, output_filename=temp_file, batch_size=10
)
num_threads = 4
results_per_thread = 25
exceptions = []
def add_results(thread_id):
try:
for i in range(results_per_thread):
manager.add_result(
{"thread_id": thread_id, "index": i, "data": f"t{thread_id}_r{i}"}
)
# Small random delay to increase contention
time.sleep(0.001)
except Exception as e:
exceptions.append((thread_id, e))
threads = []
for i in range(num_threads):
t = threading.Thread(target=add_results, args=(i,))
threads.append(t)
t.start()
# Wait for all threads with timeout
for t in threads:
t.join(timeout=30.0)
# Check no thread is stuck (deadlock)
assert not any(t.is_alive() for t in threads), "Some threads are still alive - likely deadlock!"
# Check no exceptions occurred
assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}"
# Verify all results were saved
expected_total = num_threads * results_per_thread
assert manager.total_saved >= expected_total - manager.batch_size, \
f"Expected at least {expected_total - manager.batch_size} saved, got {manager.total_saved}"
finally:
if temp_file.exists():
temp_file.unlink()
class TestParallelProcessing:
"""Test parallel processing"""