Files
OmniSafeBench-MM/config/config_loader.py
Liao, Jie 04a1cbe8d1 refactor: consolidate duplicate code and fix deadlock bug
- Extract duplicate attack/defense config merging into _merge_component_configs()
- Extract duplicate lazy loading logic into _get_component()
- Move content policy detection to BaseModel base class
- Fix BatchSaveManager deadlock by splitting flush logic
- Add TypeError to ValueError conversion for consistent config errors
- Move _determine_load_model() to BaseComponent (explicit field only)
2026-01-23 12:55:34 +08:00

483 lines
16 KiB
Python

"""
1. general_config.yaml - General experiment configuration
2. model_config.yaml - Model detailed configuration
3. attacks/ - Attack method configuration
4. defenses/ - Defense method configuration
"""
import os
import json
import yaml
from pathlib import Path
from typing import Dict, Any, Union, List, Optional
import copy
from core.data_formats import PipelineConfig
class ConfigLoader:
    """Load the pipeline's configuration files and merge them into one dict.

    The loader reads a general configuration file, a model configuration
    file, and per-component files under ``attacks/`` and ``defenses/``
    inside ``config_dir``. Loaded files are cached on the instance.
    """

    # Suffixes tried, in order, when looking up an attack/defense file.
    # ".yaml" and ".json" keep their historical priority; ".yml" is tried last.
    _COMPONENT_SUFFIXES = (".yaml", ".json", ".yml")

    def __init__(self, config_dir: str = "config"):
        """
        Initialize configuration loader
        Args:
            config_dir: Configuration directory path
        """
        self.config_dir = Path(config_dir)
        # Cache configurations
        self._general_config = None
        self._model_config = None
        self._attack_configs = {}
        self._defense_configs = {}

    def load_general_config(
        self, config_file: str = "general_config.yaml"
    ) -> Dict[str, Any]:
        """
        Load general configuration file

        NOTE: the first successful load is cached; later calls return the
        cached dictionary even if a different ``config_file`` is passed.

        Args:
            config_file: General configuration file name
        Returns:
            General configuration dictionary
        Raises:
            FileNotFoundError: If the file does not exist in ``config_dir``
        """
        if self._general_config is not None:
            return self._general_config
        config_path = self.config_dir / config_file
        if not config_path.exists():
            raise FileNotFoundError(
                f"General configuration file does not exist: {config_path}"
            )
        self._general_config = self._load_yaml_file(config_path)
        return self._general_config

    def load_model_config(
        self, config_file: str = "model_config.yaml"
    ) -> Dict[str, Any]:
        """
        Load model configuration file

        NOTE: the first successful load is cached; later calls return the
        cached dictionary even if a different ``config_file`` is passed.

        Args:
            config_file: Model configuration file name
        Returns:
            Model configuration dictionary
        Raises:
            FileNotFoundError: If the file does not exist in ``config_dir``
        """
        if self._model_config is not None:
            return self._model_config
        config_path = self.config_dir / config_file
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file does not exist: {config_path}"
            )
        self._model_config = self._load_yaml_file(config_path)
        return self._model_config

    def load_attack_config(self, attack_name: str) -> Dict[str, Any]:
        """
        Load attack method configuration
        Args:
            attack_name: Attack method name (file stem under ``attacks/``)
        Returns:
            Attack configuration dictionary (cached per name)
        Raises:
            FileNotFoundError: If no matching file exists
        """
        return self._load_component_config(
            "attacks", attack_name, self._attack_configs, "Attack"
        )

    def load_defense_config(self, defense_name: str) -> Dict[str, Any]:
        """
        Load defense method configuration
        Args:
            defense_name: Defense method name (file stem under ``defenses/``);
                the literal string "None" means "no defense"
        Returns:
            Defense configuration dictionary (cached per name)
        Raises:
            FileNotFoundError: If no matching file exists
        """
        if defense_name == "None":
            # Placeholder entry: no file is read and nothing is cached.
            return {"name": "None", "description": "No defense", "parameters": {}}
        return self._load_component_config(
            "defenses", defense_name, self._defense_configs, "Defense"
        )

    def _load_component_config(
        self,
        subdir: str,
        component_name: str,
        cache: Dict[str, Any],
        label: str,
    ) -> Dict[str, Any]:
        """Load ``<config_dir>/<subdir>/<component_name>.<ext>`` with caching.

        Args:
            subdir: Sub-directory of ``config_dir`` to search
            component_name: File stem of the component configuration
            cache: Per-kind cache dictionary, updated in place
            label: Capitalised kind name used in the error message
        Returns:
            Component configuration dictionary
        Raises:
            FileNotFoundError: If no file with a supported suffix exists
        """
        if component_name in cache:
            return cache[component_name]
        component_dir = self.config_dir / subdir
        for suffix in self._COMPONENT_SUFFIXES:
            config_path = component_dir / f"{component_name}{suffix}"
            if config_path.exists():
                config = self._load_config_file(config_path)
                cache[component_name] = config
                return config
        raise FileNotFoundError(
            f"{label} configuration file does not exist: {component_name}"
        )

    def load_all_configs(
        self, general_config_file: str = "general_config.yaml"
    ) -> "PipelineConfig":
        """
        Load all configurations and merge into PipelineConfig
        Args:
            general_config_file: General configuration file name
        Returns:
            PipelineConfig object built via ``PipelineConfig.from_dict``
        """
        general_config = self.load_general_config(general_config_file)
        model_config = self.load_model_config()
        full_config = self._build_full_config(general_config, model_config)
        return PipelineConfig.from_dict(full_config)

    def _build_full_config(
        self, general_config: Dict[str, Any], model_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Build complete configuration dictionary

        Works on a deep copy of ``general_config`` and fills in
        ``attack_params``, ``model_params``, ``defense_params`` and the
        evaluator parameters. Sections that are missing or not dictionaries
        are left untouched.

        Args:
            general_config: General configuration
            model_config: Model configuration
        Returns:
            Complete configuration dictionary
        """
        # Deep copy so the cached general configuration is never mutated.
        full_config = copy.deepcopy(general_config) or {}

        # Test case generation: resolve attack parameters.
        test_case_cfg = full_config.get("test_case_generation")
        if isinstance(test_case_cfg, dict) and "attacks" in test_case_cfg:
            test_case_cfg["attack_params"] = self._merge_component_configs(
                component_names=test_case_cfg["attacks"] or [],
                override_params=test_case_cfg.get("attack_params", {}) or {},
                config_loader_func=self.load_attack_config,
            )

        # Response generation: resolve model and defense parameters.
        response_cfg = full_config.get("response_generation")
        if isinstance(response_cfg, dict):
            if "models" in response_cfg:
                model_overrides = response_cfg.get("model_params", {}) or {}
                merged_model_params = {}
                for model_name in response_cfg["models"] or []:
                    model_info = self._find_model_config(model_name, model_config)
                    if model_info is None:
                        print(
                            f"Warning: Model configuration does not exist: {model_name}"
                        )
                        # Add empty configuration to avoid subsequent errors
                        merged_model_params[model_name] = {}
                        continue
                    # Overrides from the general config win over the model file.
                    merged_model_params[model_name] = self._deep_merge(
                        model_info, model_overrides.get(model_name)
                    )
                response_cfg["model_params"] = merged_model_params

            if "defenses" in response_cfg:
                response_cfg["defense_params"] = self._merge_component_configs(
                    component_names=response_cfg["defenses"] or [],
                    override_params=response_cfg.get("defense_params", {}) or {},
                    config_loader_func=self.load_defense_config,
                    skip_value="None",
                )

        # Evaluation: expand each evaluator's "model" reference.
        eval_cfg = full_config.get("evaluation")
        if isinstance(eval_cfg, dict):
            evaluators = eval_cfg.get("evaluator_params") or {}
            for evaluator_params in evaluators.values():
                if not isinstance(evaluator_params, dict):
                    continue
                if "model" not in evaluator_params:
                    continue
                model_info = self._find_model_config(
                    evaluator_params["model"], model_config
                )
                if model_info:
                    # Model settings are written over the evaluator's own keys.
                    evaluator_params.update(model_info)
        return full_config

    def _find_model_config(
        self, model_name: str, model_config: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Find specified model in model configuration (supports new provider structure)

        The returned dictionary is an independent deep copy, so callers may
        modify it without affecting ``model_config``.

        Args:
            model_name: Model name
            model_config: Model configuration dictionary
        Returns:
            Model configuration information, returns None if not found
        """
        if not model_config:
            return None
        # "defaults" and its keys are not specific model configurations.
        defaults = model_config.get("defaults") or {}
        if model_name == "defaults" or model_name in defaults:
            return None

        if "providers" in model_config:
            providers = model_config["providers"] or {}
            for provider_name, provider_config in providers.items():
                provider_config = provider_config or {}
                models = provider_config.get("models") or {}
                if model_name not in models:
                    continue
                # An entry declared with no body (None) becomes an empty config.
                model_info = copy.deepcopy(models[model_name]) or {}
                # Inherit provider-level settings unless the model sets its own.
                for inherited_key in ("api_key", "base_url"):
                    if inherited_key in provider_config:
                        model_info.setdefault(
                            inherited_key, provider_config[inherited_key]
                        )
                model_info["provider"] = provider_name
                return model_info
            # With a provider structure present, the flat lookup is not used.
            return None

        # Backward compatibility: directly find model (old flat structure)
        if model_name in model_config:
            return copy.deepcopy(model_config[model_name])
        return None

    def _deep_merge(
        self, base: Dict[str, Any], override: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Deep merge two dictionaries

        Nested dictionaries present on both sides are merged recursively;
        any other value from ``override`` replaces the one in ``base``.
        ``None`` for either argument is treated as an empty dictionary, and
        the result shares no mutable objects with either input.

        Args:
            base: Base dictionary
            override: Override dictionary
        Returns:
            Merged dictionary
        """
        result = {} if base is None else copy.deepcopy(base)
        for key, value in (override or {}).items():
            if isinstance(result.get(key), dict) and isinstance(value, dict):
                result[key] = self._deep_merge(result[key], value)
            else:
                result[key] = copy.deepcopy(value)
        return result

    def _merge_component_configs(
        self,
        component_names: List[str],
        override_params: Dict[str, Any],
        config_loader_func,  # Callable: takes name, returns config dict
        skip_value: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Merge component configurations from files with runtime overrides.

        Components whose file cannot be found are reported with a warning
        and left out of the result.

        Args:
            component_names: List of component names to process
            override_params: Runtime override parameters from general_config
            config_loader_func: Function to load base config for a component
            skip_value: Special value that should be skipped with empty config
        Returns:
            Merged configuration dictionary
        """
        overrides = override_params or {}
        merged_params = {}
        for component_name in component_names or []:
            if skip_value is not None and component_name == skip_value:
                merged_params[component_name] = {}
                continue
            try:
                base_config = config_loader_func(component_name)
            except FileNotFoundError:
                print(
                    f"Warning: Component configuration file does not exist: {component_name}"
                )
                continue
            # A missing file body or an empty "parameters:" key loads as None.
            base_params = (base_config or {}).get("parameters") or {}
            # _deep_merge copies, so the loader's cached dict is never exposed.
            merged_params[component_name] = self._deep_merge(
                base_params, overrides.get(component_name)
            )
        return merged_params

    def _load_yaml_file(self, file_path: Path) -> Dict[str, Any]:
        """Load YAML file; an empty document yields an empty dictionary."""
        with open(file_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return {} if data is None else data

    def _load_json_file(self, file_path: Path) -> Dict[str, Any]:
        """Load JSON file; a top-level ``null`` yields an empty dictionary."""
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {} if data is None else data

    def _load_config_file(self, file_path: Path) -> Dict[str, Any]:
        """Load configuration file based on file extension

        Raises:
            ValueError: If the suffix is not .yaml, .yml or .json
        """
        suffix = file_path.suffix.lower()
        if suffix in (".yaml", ".yml"):
            return self._load_yaml_file(file_path)
        if suffix == ".json":
            return self._load_json_file(file_path)
        raise ValueError(f"Unsupported configuration file format: {suffix}")
def load_config(config_file: str = "config/general_config.yaml") -> PipelineConfig:
    """
    Load the whole pipeline configuration starting from the general config file.

    The path is split in two: its parent directory becomes the configuration
    directory handed to ``ConfigLoader``, and its final component is used as
    the general configuration file name.

    Args:
        config_file: Path (absolute or relative) of the general configuration file
    Returns:
        PipelineConfig object
    """
    general_path = Path(config_file)
    # Directory part locates all config files; name part selects the general one.
    return ConfigLoader(str(general_path.parent)).load_all_configs(general_path.name)
def validate_config(config: "PipelineConfig") -> bool:
    """
    Validate configuration validity

    Checks that an output directory, a behaviors file, at least one attack
    and at least one model are configured. A model without an entry in
    ``model_params`` only produces a warning. Sections or sub-sections that
    are ``None`` are treated as empty instead of raising.

    Args:
        config: PipelineConfig object; ``test_case_generation`` and
            ``response_generation`` are read with ``.get`` (presumably
            dictionaries; verify against PipelineConfig)
    Returns:
        Whether configuration is valid
    """
    # Basic validation
    if not config.output_dir:
        print("Error: Output directory is not set")
        return False

    # Test case generation configuration validation
    test_case_cfg = config.test_case_generation or {}
    input_cfg = test_case_cfg.get("input") or {}
    if not input_cfg.get("behaviors_file"):
        print("Error: Harmful behavior file is not set")
        return False
    # No longer check image_dir, as image paths are now read directly from behavior data
    if not test_case_cfg.get("attacks"):
        print("Error: Attack methods are not set")
        return False

    # Response generation configuration validation
    response_cfg = config.response_generation or {}
    model_names = response_cfg.get("models")
    if not model_names:
        print("Error: Models are not set")
        return False

    # Check if model configurations are complete (warning only)
    model_params = response_cfg.get("model_params") or {}
    for model_name in model_names:
        if model_name not in model_params:
            print(f"Warning: Model '{model_name}' configuration does not exist")
    return True
# Global model configuration lookup function
def get_model_config(
    model_name: str, config_dir: str = "config"
) -> Optional[Dict[str, Any]]:
    """
    Look up a single model's configuration by name.

    A fresh ``ConfigLoader`` is created for ``config_dir``; any exception
    raised while loading or searching is reported as a printed warning and
    turned into a ``None`` result.

    Args:
        model_name: Model name
        config_dir: Configuration directory path, defaults to "config"
    Returns:
        Model configuration dictionary, returns None if not found
    """
    try:
        config_loader = ConfigLoader(config_dir)
        return config_loader._find_model_config(
            model_name, config_loader.load_model_config()
        )
    except Exception as e:
        # Best-effort lookup: never propagate errors to the caller.
        print(f"Warning: Failed to find model configuration {model_name}: {e}")
    return None