Mirror of https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git (synced 2026-02-12 17:52:46 +00:00)
refactor: consolidate duplicate code and fix deadlock bug
- Extract duplicate attack/defense config merging into _merge_component_configs()
- Extract duplicate lazy loading logic into _get_component()
- Move content policy detection to the BaseModel base class
- Fix BatchSaveManager deadlock by splitting the flush logic
- Add TypeError to ValueError conversion for consistent config errors
- Move _determine_load_model() to BaseComponent (explicit field only)
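The deadlock fix is the headline change: add_result() held self.lock and then called a flush method that tried to acquire the same non-reentrant threading.Lock again. A minimal standalone sketch of the failure mode and the split-method fix, assuming a hypothetical class whose names mirror (but are not) the project's BatchSaveManager:

    import threading

    class Buffered:
        """Minimal sketch of the BatchSaveManager locking pattern (hypothetical)."""

        def __init__(self, batch_size=3):
            self.lock = threading.Lock()  # non-reentrant: a second acquire by the same thread blocks forever
            self.buffer = []
            self.batch_size = batch_size

        def add_result(self, item):
            with self.lock:
                self.buffer.append(item)
                if len(self.buffer) >= self.batch_size:
                    # Before the fix this called the public flush(), which
                    # re-acquired self.lock inside the critical section: deadlock.
                    self._flush_unlocked()

        def flush(self):
            """Thread-safe public entry point."""
            with self.lock:
                self._flush_unlocked()

        def _flush_unlocked(self):
            """Caller must already hold self.lock."""
            batch, self.buffer = self.buffer, []
            # ... persist batch ...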
@@ -188,27 +188,11 @@ class ConfigLoader:
             attack_names = test_case_cfg["attacks"]
             attack_params = test_case_cfg.get("attack_params", {}) or {}
 
-            # Merge attack configurations
-            merged_attack_params = {}
-            for attack_name in attack_names:
-                try:
-                    attack_config = self.load_attack_config(attack_name)
-                    base_params = attack_config.get("parameters", {})
-
-                    # If there are override parameters in general config, merge them
-                    if attack_name in attack_params:
-                        merged_attack_params[attack_name] = self._deep_merge(
-                            base_params, attack_params[attack_name]
-                        )
-                    else:
-                        merged_attack_params[attack_name] = base_params
-                except FileNotFoundError:
-                    print(
-                        f"Warning: Attack configuration file does not exist: {attack_name}"
-                    )
-                    continue
-
-            test_case_cfg["attack_params"] = merged_attack_params
+            test_case_cfg["attack_params"] = self._merge_component_configs(
+                component_names=attack_names,
+                override_params=attack_params,
+                config_loader_func=self.load_attack_config,
+            )
 
         # Process response generation configuration
         if "response_generation" in full_config:
@@ -245,31 +229,12 @@ class ConfigLoader:
             defense_names = response_cfg["defenses"]
             defense_overrides = response_cfg.get("defense_params", {}) or {}
 
-            # Merge defense configurations
-            merged_defense_params = {}
-            for defense_name in defense_names:
-                if defense_name == "None":
-                    merged_defense_params[defense_name] = {}
-                    continue
-
-                try:
-                    defense_config = self.load_defense_config(defense_name)
-                    base_params = defense_config.get("parameters", {})
-
-                    # If there are override parameters in general config, merge them
-                    if defense_name in defense_overrides:
-                        merged_defense_params[defense_name] = self._deep_merge(
-                            base_params, defense_overrides[defense_name]
-                        )
-                    else:
-                        merged_defense_params[defense_name] = base_params
-                except FileNotFoundError:
-                    print(
-                        f"Warning: Defense configuration file does not exist: {defense_name}"
-                    )
-                    continue
-
-            response_cfg["defense_params"] = merged_defense_params
+            response_cfg["defense_params"] = self._merge_component_configs(
+                component_names=defense_names,
+                override_params=defense_overrides,
+                config_loader_func=self.load_defense_config,
+                skip_value="None",
+            )
 
         # Process evaluation configuration
         if "evaluation" in full_config:
@@ -361,6 +326,50 @@ class ConfigLoader:
 
         return result
 
+    def _merge_component_configs(
+        self,
+        component_names: List[str],
+        override_params: Dict[str, Any],
+        config_loader_func,  # Callable: takes name, returns config dict
+        skip_value: str = None,
+    ) -> Dict[str, Any]:
+        """Merge component configurations from files with runtime overrides.
+
+        Args:
+            component_names: List of component names to process
+            override_params: Runtime override parameters from general_config
+            config_loader_func: Function to load base config for a component
+            skip_value: Special value that should be skipped with empty config
+
+        Returns:
+            Merged configuration dictionary
+        """
+        merged_params = {}
+
+        for component_name in component_names:
+            if skip_value is not None and component_name == skip_value:
+                merged_params[component_name] = {}
+                continue
+
+            try:
+                base_config = config_loader_func(component_name)
+                base_params = base_config.get("parameters", {})
+
+                if component_name in override_params:
+                    merged_params[component_name] = self._deep_merge(
+                        base_params, override_params[component_name]
+                    )
+                else:
+                    merged_params[component_name] = base_params
+
+            except FileNotFoundError:
+                print(
+                    f"Warning: Component configuration file does not exist: {component_name}"
+                )
+                continue
+
+        return merged_params
+
     def _load_yaml_file(self, file_path: Path) -> Dict[str, Any]:
         """Load YAML file"""
        with open(file_path, "r", encoding="utf-8") as f:
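A sketch of how the consolidated helper behaves, given a ConfigLoader instance loader; the component name and parameter values below are hypothetical (real ones come from the attack/defense YAML files):

    # Hypothetical: my_attack.yaml declares parameters {"epsilon": 8, "steps": 10}
    # and the general config overrides epsilon only.
    merged = loader._merge_component_configs(
        component_names=["my_attack"],
        override_params={"my_attack": {"epsilon": 16}},
        config_loader_func=loader.load_attack_config,
    )
    # merged == {"my_attack": {"epsilon": 16, "steps": 10}} after _deep_merge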
@@ -47,7 +47,11 @@ class BaseComponent(ABC):
             cfg_dict = config
             allowed_fields = {f.name for f in fields(self.CONFIG_CLASS)}
             filtered = {k: v for k, v in cfg_dict.items() if k in allowed_fields}
-            self.cfg = self.CONFIG_CLASS(**filtered)
+            try:
+                self.cfg = self.CONFIG_CLASS(**filtered)
+            except TypeError as e:
+                # Convert TypeError to ValueError for consistency
+                raise ValueError(f"Invalid configuration: {e}") from e
         else:
             # No configuration class or not a dataclass, use dictionary directly
             self.cfg = config
@@ -77,6 +81,24 @@ class BaseComponent(ABC):
             pass
         # Subclasses can add more validation logic
 
+    def _determine_load_model(self) -> bool:
+        """Determine if local model needs to be loaded.
+
+        Checks the load_model field in configuration (dict or dataclass).
+        Returns False if load_model is not explicitly set.
+        """
+        if hasattr(self, "cfg"):
+            config_obj = self.cfg
+        else:
+            config_obj = self.config
+
+        if isinstance(config_obj, dict):
+            return config_obj.get("load_model", False)
+        elif hasattr(config_obj, "load_model"):
+            return getattr(config_obj, "load_model", False)
+
+        return False
+
 
 class BaseAttack(BaseComponent, ABC):
     """Attack method base class (enhanced version)"""
@@ -109,25 +131,6 @@ class BaseAttack(BaseComponent, ABC):
         # Determine if local model needs to be loaded
         self.load_model = self._determine_load_model()
 
-    def _determine_load_model(self) -> bool:
-        """Determine if local model needs to be loaded
-
-        Only decide based on the load_model field in configuration
-        """
-        # Check if load_model field exists in configuration
-        if hasattr(self, "cfg"):
-            config_obj = self.cfg
-        else:
-            config_obj = self.config
-
-        # Only check load_model configuration item
-        if isinstance(config_obj, dict):
-            return config_obj.get("load_model", False)
-        elif hasattr(config_obj, "load_model"):
-            return getattr(config_obj, "load_model", False)
-
-        return False
-
     @abstractmethod
     def generate_test_case(
         self, original_prompt: str, image_path: str, case_id: str, **kwargs
@@ -259,25 +262,6 @@ class BaseDefense(BaseComponent, ABC):
         # Determine if local model needs to be loaded
         self.load_model = self._determine_load_model()
 
-    def _determine_load_model(self) -> bool:
-        """Determine if local model needs to be loaded
-
-        Only based on the load_model field in configuration
-        """
-        # Check if load_model field exists in configuration
-        if hasattr(self, "cfg"):
-            config_obj = self.cfg
-        else:
-            config_obj = self.config
-
-        # Only check load_model configuration item
-        if isinstance(config_obj, dict):
-            return config_obj.get("load_model", False)
-        elif hasattr(config_obj, "load_model"):
-            return getattr(config_obj, "load_model", False)
-
-        return False
-
     @abstractmethod
     def apply_defense(self, test_case: TestCase, **kwargs) -> TestCase:
         """
@@ -109,124 +109,79 @@ class UnifiedRegistry:
         self.evaluator_registry[name] = evaluator_class
         self.logger.debug(f"Registered evaluator: {name}")
 
-    def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]:
-        """Get attack method class"""
-        if name in self.attack_registry:
-            return self.attack_registry[name]
+    def _get_component(
+        self,
+        name: str,
+        component_type: str,  # "attacks", "models", "defenses", "evaluators"
+    ) -> Optional[Type]:
+        """Generic component getter with lazy loading.
 
-        # Get mapping information from config/plugins.yaml
+        Args:
+            name: Component name
+            component_type: One of "attacks", "models", "defenses", "evaluators"
+
+        Returns:
+            Component class or None if not found
+        """
+        type_info = {
+            "attacks": ("attack_registry", "attack method"),
+            "models": ("model_registry", "model"),
+            "defenses": ("defense_registry", "defense method"),
+            "evaluators": ("evaluator_registry", "evaluator"),
+        }
+
+        if component_type not in type_info:
+            raise ValueError(f"Invalid component_type: {component_type}")
+
+        registry_attr, type_name = type_info[component_type]
+        registry = getattr(self, registry_attr)
+
+        # Check cache first
+        if name in registry:
+            return registry[name]
+
+        # Try lazy loading from plugins.yaml
         try:
             mappings = self._get_lazy_mappings()
-            if name in mappings["attacks"]:
-                module_path, class_name = mappings["attacks"][name]
+            if name in mappings[component_type]:
+                module_path, class_name = mappings[component_type][name]
                 module = importlib.import_module(module_path)
                 cls = getattr(module, class_name)
-                # Register to cache
-                self.attack_registry[name] = cls
+                # Cache for future access
+                registry[name] = cls
                 self.logger.debug(
-                    f"Successfully imported attack method from mapping: {name}"
+                    f"Successfully imported {type_name} from mapping: {name}"
                 )
                 return cls
         except (ImportError, AttributeError) as e:
             self.logger.debug(
-                f"Unable to import attack method '{name}' from mapping: {e}"
+                f"Unable to import {type_name} '{name}' from mapping: {e}"
             )
             return None
         except Exception as e:
             self.logger.error(
-                f"Unknown error occurred while importing attack method '{name}': {e}"
+                f"Unknown error occurred while importing {type_name} '{name}': {e}"
             )
             return None
 
-        self.logger.warning(f"Attack method '{name}' is not defined in mapping")
+        self.logger.warning(f"{type_name.capitalize()} '{name}' is not defined in mapping")
         return None
 
+    def get_attack(self, name: str) -> Optional[Type["BaseAttack"]]:
+        """Get attack method class"""
+        return self._get_component(name, "attacks")
+
     def get_model(self, name: str) -> Optional[Type["BaseModel"]]:
         """Get model class"""
-        if name in self.model_registry:
-            return self.model_registry[name]
-
-        # Get mapping information from config/plugins.yaml
-        try:
-            mappings = self._get_lazy_mappings()
-            if name in mappings["models"]:
-                module_path, class_name = mappings["models"][name]
-                module = importlib.import_module(module_path)
-                cls = getattr(module, class_name)
-                self.model_registry[name] = cls
-                self.logger.debug(f"Successfully imported model from mapping: {name}")
-                return cls
-        except (ImportError, AttributeError) as e:
-            self.logger.debug(f"Unable to import model '{name}' from mapping: {e}")
-            return None
-        except Exception as e:
-            self.logger.error(
-                f"Unknown error occurred while importing model '{name}': {e}"
-            )
-            return None
-
-        self.logger.warning(f"Model '{name}' is not defined in mapping")
-        return None
+        return self._get_component(name, "models")
 
     def get_defense(self, name: str) -> Optional[Type["BaseDefense"]]:
         """Get defense method class"""
-        if name in self.defense_registry:
-            return self.defense_registry[name]
-
-        # Get mapping information from config/plugins.yaml
-        try:
-            mappings = self._get_lazy_mappings()
-            if name in mappings["defenses"]:
-                module_path, class_name = mappings["defenses"][name]
-                module = importlib.import_module(module_path)
-                cls = getattr(module, class_name)
-                self.defense_registry[name] = cls
-                self.logger.debug(
-                    f"Successfully imported defense method from mapping: {name}"
-                )
-                return cls
-        except (ImportError, AttributeError) as e:
-            self.logger.debug(
-                f"Unable to import defense method '{name}' from mapping: {e}"
-            )
-            return None
-        except Exception as e:
-            self.logger.error(
-                f"Unknown error occurred while importing defense method '{name}': {e}"
-            )
-            return None
-
-        self.logger.warning(f"Defense method '{name}' is not defined in mapping")
-        return None
+        return self._get_component(name, "defenses")
 
     def get_evaluator(self, name: str) -> Optional[Type["BaseEvaluator"]]:
         """Get evaluator class"""
-        if name in self.evaluator_registry:
-            return self.evaluator_registry[name]
-
-        # Get mapping information from config/plugins.yaml
-        try:
-            mappings = self._get_lazy_mappings()
-            if name in mappings["evaluators"]:
-                module_path, class_name = mappings["evaluators"][name]
-                module = importlib.import_module(module_path)
-                cls = getattr(module, class_name)
-                self.evaluator_registry[name] = cls
-                self.logger.debug(
-                    f"Successfully imported evaluator from mapping: {name}"
-                )
-                return cls
-        except (ImportError, AttributeError) as e:
-            self.logger.debug(f"Unable to import evaluator '{name}' from mapping: {e}")
-            return None
-        except Exception as e:
-            self.logger.error(
-                f"Unknown error occurred while importing evaluator '{name}': {e}"
-            )
-            return None
-
-        self.logger.warning(f"Evaluator '{name}' is not defined in mapping")
-        return None
+        return self._get_component(name, "evaluators")
 
     def create_attack(
         self, name: str, config: Dict[str, Any] = None, output_image_dir: str = None
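Caller-facing behavior is unchanged; a brief usage sketch with a hypothetical component name (actual names are defined in config/plugins.yaml):

    registry = UnifiedRegistry()
    cls = registry.get_attack("some_attack")   # first call: imported lazily via the plugins.yaml mapping
    cls = registry.get_attack("some_attack")   # second call: served from the attack_registry cache
    if cls is None:
        print("'some_attack' is neither registered nor mapped in plugins.yaml")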
@@ -7,6 +7,11 @@ class AnthropicModel(BaseModel):
 
     default_output = "I'm sorry, but I cannot assist with that request."
 
+    # Anthropic-specific content policy keywords
+    PROVIDER_SPECIFIC_KEYWORDS = [
+        "output blocked by content filtering policy",
+    ]
+
     def __init__(self, model_name: str, api_key: str) -> None:
         super().__init__(model_name, api_key)
@@ -53,15 +58,8 @@ class AnthropicModel(BaseModel):
                 )
                 return response
             except Exception as e:
-                error_str = str(e).lower()
-                if "Output blocked by content filtering policy" in str(e):
-                    return self.API_CONTENT_REJECTION_OUTPUT
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-                    return self.API_CONTENT_REJECTION_OUTPUT
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection()
                 raise
 
         return self._retry_with_backoff(_api_call)
@@ -100,23 +98,8 @@ class AnthropicModel(BaseModel):
                 )
                 return stream
             except Exception as e:
-                error_str = str(e).lower()
-                if "Output blocked by content filtering policy" in str(e):
-                    # Return a generator that yields the content rejection placeholder
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection_stream()
                 raise
 
         try:
@@ -128,4 +111,4 @@ class AnthropicModel(BaseModel):
             elif hasattr(chunk, "completion") and chunk.completion:
                 yield chunk.completion
         except Exception:
-            yield self.API_ERROR_OUTPUT
+            yield self._handle_api_error()
@@ -18,6 +18,21 @@ class BaseModel(CoreBaseModel):
     API_MAX_RETRY = 3
     API_TIMEOUT = 600
 
+    # Content policy detection keywords (common across providers)
+    CONTENT_POLICY_KEYWORDS = [
+        "content policy",
+        "safety",
+        "harmful",
+        "unsafe",
+        "violation",
+        "moderation",
+        "data_inspection_failed",
+        "inappropriate content",
+    ]
+
+    # Provider-specific additional keywords (subclasses can override)
+    PROVIDER_SPECIFIC_KEYWORDS = []
+
     def __init__(self, model_name: str, api_key: str = None, base_url: str = None):
         # Call parent class __init__, pass empty configuration
         super().__init__(config={})
@@ -76,6 +91,46 @@ class BaseModel(CoreBaseModel):
         else:
             return "local"
 
+    def _is_content_policy_rejection(self, error: Exception) -> bool:
+        """Check if an exception represents a content policy rejection.
+
+        Args:
+            error: The exception to check
+
+        Returns:
+            True if the error indicates content policy rejection
+        """
+        error_str = str(error).lower()
+
+        # Combine common and provider-specific keywords
+        all_keywords = self.CONTENT_POLICY_KEYWORDS + self.PROVIDER_SPECIFIC_KEYWORDS
+
+        return any(keyword in error_str for keyword in all_keywords)
+
+    def _handle_content_rejection(self) -> str:
+        """Return the standard content rejection output."""
+        return self.API_CONTENT_REJECTION_OUTPUT
+
+    def _handle_content_rejection_stream(self):
+        """Return a generator that yields the content rejection placeholder."""
+
+        def error_generator():
+            yield self.API_CONTENT_REJECTION_OUTPUT
+
+        return error_generator()
+
+    def _handle_api_error(self) -> str:
+        """Return the standard API error output."""
+        return self.API_ERROR_OUTPUT
+
+    def _handle_api_error_stream(self):
+        """Return a generator that yields the API error placeholder."""
+
+        def error_generator():
+            yield self.API_ERROR_OUTPUT
+
+        return error_generator()
+
     def _retry_with_backoff(self, func, *args, **kwargs):
         """Execute function with retry logic and exponential backoff using backoff library."""
@@ -33,20 +33,8 @@ class DoubaoModel(BaseModel):
                 )
                 return response
             except Exception as e:
-                # Check for content policy violations in ByteDance models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    return self.API_CONTENT_REJECTION_OUTPUT
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection()
                 raise
 
         return self._retry_with_backoff(_api_call)
@@ -68,34 +56,8 @@ class DoubaoModel(BaseModel):
                 )
                 return stream
             except Exception as e:
-                # Check for content policy violations in ByteDance models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    # Return a generator that yields the content rejection placeholder
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection_stream()
                 raise
 
         try:
@@ -104,4 +66,4 @@ class DoubaoModel(BaseModel):
             if chunk.choices[0].delta.content is not None:
                 yield chunk.choices[0].delta.content
         except Exception:
-            yield self.API_ERROR_OUTPUT
+            yield self._handle_api_error()
@@ -7,6 +7,11 @@ class GoogleModel(BaseModel):
 
     default_output = "I'm sorry, but I cannot assist with that request."
 
+    # Google-specific content policy keywords
+    PROVIDER_SPECIFIC_KEYWORDS = [
+        "blocked",
+    ]
+
     def __init__(self, model_name: str, api_key: str, base_url: str = None) -> None:
         super().__init__(model_name=model_name, api_key=api_key, base_url=base_url)
@@ -37,27 +42,8 @@ class GoogleModel(BaseModel):
                 )
                 return response
             except Exception as e:
-                # Check for content policy violations in Gemini models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "blocked",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    return self.API_CONTENT_REJECTION_OUTPUT
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-                    return self.API_CONTENT_REJECTION_OUTPUT
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection()
                 raise
 
         return self._retry_with_backoff(_api_call)
@@ -86,4 +72,4 @@ class GoogleModel(BaseModel):
             if chunk.text:
                 yield chunk.text
         except Exception:
-            yield self.API_ERROR_OUTPUT
+            yield self._handle_api_error()
@@ -47,26 +47,8 @@ class MistralModel(BaseModel):
                 )
                 return chat_response
             except Exception as e:
-                # Check for content policy violations in Mistral models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    return self.API_CONTENT_REJECTION_OUTPUT
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-                    return self.API_CONTENT_REJECTION_OUTPUT
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection()
                 raise
 
         return self._retry_with_backoff(_api_call)
@@ -99,34 +81,8 @@ class MistralModel(BaseModel):
                 )
                 return stream
             except Exception as e:
-                # Check for content policy violations in Mistral models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    # Return a generator that yields the content rejection placeholder
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection_stream()
                 raise
 
         try:
@@ -136,4 +92,4 @@ class MistralModel(BaseModel):
             if chunk.choices[0].delta.content is not None:
                 yield chunk.choices[0].delta.content
         except Exception:
-            yield self.API_ERROR_OUTPUT
+            yield self._handle_api_error()
@@ -5,6 +5,14 @@ from .base_model import BaseModel
 class OpenAIModel(BaseModel):
     """OpenAI model implementation using OpenAI API."""
 
+    # OpenAI-specific content policy keywords
+    PROVIDER_SPECIFIC_KEYWORDS = [
+        "invalid",
+        "inappropriate",
+        "invalid_prompt",
+        "limited access",
+    ]
+
     def __init__(self, model_name: str, api_key: str, base_url: Optional[str] = None):
         super().__init__(model_name=model_name, api_key=api_key, base_url=base_url)
@@ -41,27 +49,11 @@ class OpenAIModel(BaseModel):
                 return response
 
             except Exception as e:
-                # Check for content policy violations in GPT models
-                error_str = str(e).lower()
-                print("Error during API call:", error_str)
-                content_keywords = [
-                    "content policy",
-                    "invalid",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate",
-                    "invalid_prompt",
-                    "limited access",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    print("✓ Content rejection triggered")
-                    return self.API_CONTENT_REJECTION_OUTPUT
-                print("✗ No content keywords matched, raising exception")
-                raise e
+                print("Error during API call:", str(e).lower())
+                if self._is_content_policy_rejection(e):
+                    print("Content rejection triggered")
+                    return self._handle_content_rejection()
+                raise
 
         return self._retry_with_backoff(_api_call)
@@ -86,34 +78,8 @@ class OpenAIModel(BaseModel):
                 )
                 return stream
             except Exception as e:
-                # Check for content policy violations in GPT models
-                error_str = str(e).lower()
-                content_keywords = [
-                    "content policy",
-                    "safety",
-                    "harmful",
-                    "unsafe",
-                    "violation",
-                    "moderation",
-                    "data_inspection_failed",
-                    "inappropriate content",
-                ]
-                if any(keyword in error_str for keyword in content_keywords):
-                    # Return a generator that yields the content rejection placeholder
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
-                # Handle BadRequestError specifically
-                if (
-                    "badrequesterror" in error_str
-                    and "data_inspection_failed" in error_str
-                ):
-
-                    def error_generator():
-                        yield self.API_CONTENT_REJECTION_OUTPUT
-
-                    return error_generator()
+                if self._is_content_policy_rejection(e):
+                    return self._handle_content_rejection_stream()
                 raise
 
         try:
@@ -122,4 +88,4 @@ class OpenAIModel(BaseModel):
             if chunk.choices[0].delta.content is not None:
                 yield chunk.choices[0].delta.content
         except Exception:
-            yield self.API_ERROR_OUTPUT
+            yield self._handle_api_error()
@@ -465,7 +465,18 @@ class BasePipeline(ABC):
         return str(filepath)
 
     def save_single_result(self, result: Dict, filename: str) -> str:
-        """Save single result to file (append mode)"""
-        filepath = self.output_dir / filename
+        """Save single result to file (append mode)
+
+        Args:
+            result: The result to save
+            filename: File path (can be absolute or relative to output_dir)
+
+        Returns:
+            str: Path to the saved file
+        """
+        # Handle both absolute and relative paths
+        filepath = Path(filename)
+        if not filepath.is_absolute():
+            filepath = self.output_dir / filename
 
         # Atomic write: write to temporary file first
@@ -590,7 +601,7 @@ class BatchSaveManager:
 
         # Check if save is needed
         if len(self.buffer) >= self.batch_size:
-            self._flush_buffer()
+            self._flush_unlocked()
 
     def add_results(self, results: List[Dict]) -> None:
         """Batch add results"""
@@ -604,11 +615,18 @@ class BatchSaveManager:
             self._save_batch(batch)
 
     def _flush_buffer(self) -> None:
-        """Save all results in buffer"""
+        """Save all results in buffer (thread-safe public method)"""
+        with self.lock:
+            self._flush_unlocked()
+
+    def _flush_unlocked(self) -> None:
+        """Internal flush method, caller must hold lock.
+        This is the actual implementation that flushes the buffer.
+        It must only be called while holding self.lock.
+        """
         if not self.buffer:
             return
 
-        with self.lock:
-            batch = self.buffer.copy()
-            self.buffer.clear()
+        batch = self.buffer.copy()
+        self.buffer.clear()
         self._save_batch(batch)
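A note on the design: the same re-entrancy could have been made legal with threading.RLock instead of splitting the method; a sketch of that alternative for comparison (not what this commit does):

    import threading

    lock = threading.RLock()  # re-entrant: the owning thread may acquire it again

    def flush():
        with lock:
            _flush()

    def _flush():
        with lock:  # deadlocks with threading.Lock(); legal with RLock
            pass

The commit keeps the lock non-reentrant and splits the method instead, which makes lock ownership explicit at each call site.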
@@ -21,20 +21,17 @@ class TestBaseClasses:
                 test_case_id="test", prompt="test", image_path="test", metadata={}
             )
 
-        # Test case requiring local model loading
-        attack_with_model = TestAttack(config={"model_path": "/path/to/model"})
+        # Test with explicit load_model=True
+        attack_with_model = TestAttack(config={"load_model": True})
         assert attack_with_model.load_model is True
 
-        # Test case not requiring local model loading
-        attack_without_model = TestAttack(config={})
+        # Test with explicit load_model=False
+        attack_without_model = TestAttack(config={"load_model": False})
         assert attack_without_model.load_model is False
 
-        # Test other local model configuration items
-        attack_with_device = TestAttack(config={"device": "cuda:0"})
-        assert attack_with_device.load_model is True
-
-        attack_with_checkpoint = TestAttack(config={"checkpoint": "model.ckpt"})
-        assert attack_with_checkpoint.load_model is True
+        # Test without load_model field (defaults to False)
+        attack_default = TestAttack(config={})
+        assert attack_default.load_model is False
 
     def test_base_component_required_field_validation(self):
         """Test required field validation (error when missing, normal when provided)"""
@@ -73,14 +70,18 @@ class TestBaseClasses:
         def apply_defense(self, test_case, **kwargs):
             return test_case
 
-        # Test case requiring local model loading
-        defense_with_model = TestDefense(config={"local_model": True})
+        # Test with explicit load_model=True
+        defense_with_model = TestDefense(config={"load_model": True})
         assert defense_with_model.load_model is True
 
-        # Test case not requiring local model loading
-        defense_without_model = TestDefense(config={})
+        # Test with explicit load_model=False
+        defense_without_model = TestDefense(config={"load_model": False})
         assert defense_without_model.load_model is False
 
+        # Test without load_model field (defaults to False)
+        defense_default = TestDefense(config={})
+        assert defense_default.load_model is False
+
     def test_base_model_model_type_detection(self):
         """Test model type detection"""
         from models.base_model import BaseModel
@@ -5,6 +5,8 @@ Test pipeline system
 import pytest
 import tempfile
 import json
+import threading
+import time
 from pathlib import Path
 from unittest.mock import Mock, patch, MagicMock
 
@@ -133,22 +135,26 @@ class TestBasePipeline:
         pipeline_config = PipelineConfig(**test_config)
         pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
 
+        # Create a temporary file just to get a unique name
         with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
-            temp_file = Path(f.name)
+            temp_file_name = Path(f.name).name
+
+        # The actual file will be created in pipeline's output_dir
+        output_file = pipeline.output_dir / temp_file_name
 
         try:
             first = {"test_case_id": "case_1", "data": 1}
             second = {"test_case_id": "case_1", "data": 2}
 
-            pipeline.save_single_result(first, temp_file.name)
-            pipeline.save_single_result(second, temp_file.name)
+            pipeline.save_single_result(first, temp_file_name)
+            pipeline.save_single_result(second, temp_file_name)
 
-            loaded = pipeline.load_results(temp_file)
+            loaded = pipeline.load_results(output_file)
             assert len(loaded) == 1
             assert loaded[0]["data"] == 2
         finally:
-            if temp_file.exists():
-                temp_file.unlink()
+            if output_file.exists():
+                output_file.unlink()
 
     def test_task_hash(self, test_config):
         """Test task hash generation"""
@@ -232,6 +238,9 @@ class TestBatchSaveManager:
             # Verify buffer has been cleared
             assert len(manager.buffer) == 0
             assert manager.total_saved == 2
+        finally:
+            if temp_file.exists():
+                temp_file.unlink()
 
     def test_batch_save_flush(self, test_config):
         """Test flush saves remaining buffer"""
@@ -261,10 +270,6 @@ class TestBatchSaveManager:
             if temp_file.exists():
                 temp_file.unlink()
 
-        finally:
-            if temp_file.exists():
-                temp_file.unlink()
-
     def test_batch_save_context_manager(self, test_config):
         """Test batch save context manager"""
         from pipeline.base_pipeline import batch_save_context
@@ -298,6 +303,110 @@ class TestBatchSaveManager:
             if temp_file.exists():
                 temp_file.unlink()
 
+    def test_batch_save_no_deadlock_when_exceeding_batch_size(self, test_config):
+        """Test that adding results equal to batch_size doesn't cause deadlock.
+
+        This test verifies the fix for the deadlock bug where:
+        - add_result() acquires self.lock
+        - When buffer >= batch_size, _flush_buffer() is called
+        - _flush_buffer() tries to acquire self.lock again → DEADLOCK
+
+        The fix uses _flush_unlocked(), which assumes the caller holds the lock.
+        """
+        from pipeline.base_pipeline import BatchSaveManager
+        from core.data_formats import PipelineConfig
+        import tempfile
+
+        pipeline_config = PipelineConfig(**test_config)
+        pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+            temp_file = Path(f.name)
+
+        try:
+            manager = BatchSaveManager(
+                pipeline=pipeline, output_filename=temp_file, batch_size=3
+            )
+
+            # Add exactly batch_size results - this would trigger the deadlock in the bug
+            results = [{"id": i, "data": f"result_{i}"} for i in range(3)]
+
+            # Use a timeout to catch potential deadlocks
+            def add_results_with_timeout():
+                for result in results:
+                    manager.add_result(result)
+
+            thread = threading.Thread(target=add_results_with_timeout)
+            thread.start()
+
+            # Wait up to 5 seconds - if deadlock occurs, this will timeout
+            thread.join(timeout=5.0)
+
+            assert not thread.is_alive(), "Thread is still alive - likely deadlock!"
+            assert manager.total_saved == 3, f"Expected 3 saved, got {manager.total_saved}"
+            assert len(manager.buffer) == 0, "Buffer should be empty after batch save"
+
+        finally:
+            if temp_file.exists():
+                temp_file.unlink()
+
+    def test_batch_save_concurrent_additions(self, test_config):
+        """Test thread-safe batch saving with concurrent additions"""
+        from pipeline.base_pipeline import BatchSaveManager
+        from core.data_formats import PipelineConfig
+        import tempfile
+
+        pipeline_config = PipelineConfig(**test_config)
+        pipeline = create_concrete_pipeline(pipeline_config, "test_case_generation")
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+            temp_file = Path(f.name)
+
+        try:
+            manager = BatchSaveManager(
+                pipeline=pipeline, output_filename=temp_file, batch_size=10
+            )
+
+            num_threads = 4
+            results_per_thread = 25
+            exceptions = []
+
+            def add_results(thread_id):
+                try:
+                    for i in range(results_per_thread):
+                        manager.add_result(
+                            {"thread_id": thread_id, "index": i, "data": f"t{thread_id}_r{i}"}
+                        )
+                        # Small delay to increase contention
+                        time.sleep(0.001)
+                except Exception as e:
+                    exceptions.append((thread_id, e))
+
+            threads = []
+            for i in range(num_threads):
+                t = threading.Thread(target=add_results, args=(i,))
+                threads.append(t)
+                t.start()
+
+            # Wait for all threads with timeout
+            for t in threads:
+                t.join(timeout=30.0)
+
+            # Check no thread is stuck (deadlock)
+            assert not any(t.is_alive() for t in threads), "Some threads are still alive - likely deadlock!"
+
+            # Check no exceptions occurred
+            assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}"
+
+            # Verify all results were saved
+            expected_total = num_threads * results_per_thread
+            assert manager.total_saved >= expected_total - manager.batch_size, \
+                f"Expected at least {expected_total - manager.batch_size} saved, got {manager.total_saved}"
+
+        finally:
+            if temp_file.exists():
+                temp_file.unlink()
+
 
 class TestParallelProcessing:
     """Test parallel processing"""