mirror of
https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git
synced 2026-03-11 14:36:07 +00:00
test: add test for per-attack image directories and improve error handling
- Add test_generate_test_cases.py to verify each attack uses its own image directory
- Improve error messages in unified_registry for failed imports with more context
- Add null checks in generate_test_cases.py for attack initialization failures
- Reduce excessive max_tokens in judge_utils from 130000 to 10000 for efficiency
- Add missing ThreadPoolExecutor import in qr/attack.py
- Fix missing newline in pyproject.toml
This commit is contained in:
@@ -14,6 +14,7 @@ from PIL import Image, ImageFont, ImageDraw
|
||||
import torch
|
||||
import threading
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from core.data_formats import TestCase
|
||||
from core.base_classes import BaseAttack
|
||||
|
||||
@@ -154,8 +154,10 @@ class UnifiedRegistry:
|
||||
)
|
||||
return cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
self.logger.debug(
|
||||
f"Unable to import {type_name} '{name}' from mapping: {e}"
|
||||
self.logger.warning(
|
||||
f"Failed to import {type_name} '{name}': {e}. "
|
||||
f"This usually means missing dependencies. "
|
||||
f"Module: {module_path}, Class: {class_name}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
|
||||
@@ -491,15 +491,15 @@ Your judgment must be more than a simple formula. You are to weigh the evidence
|
||||
|
||||
# Generate evaluation responses - use adapter function
|
||||
evaluation_harmfulness_response = _call_model_generate(
|
||||
judge_client, evaluate_harmfulness_messages, max_tokens=130000, temperature=0
|
||||
judge_client, evaluate_harmfulness_messages, max_tokens=10000, temperature=0
|
||||
)
|
||||
|
||||
evaluation_alignment_response = _call_model_generate(
|
||||
judge_client, evaluate_alignment_messages, max_tokens=130000, temperature=0
|
||||
judge_client, evaluate_alignment_messages, max_tokens=10000, temperature=0
|
||||
)
|
||||
|
||||
evaluation_detail_response = _call_model_generate(
|
||||
judge_client, evaluate_detail_messages, max_tokens=130000, temperature=0
|
||||
judge_client, evaluate_detail_messages, max_tokens=10000, temperature=0
|
||||
)
|
||||
|
||||
# Extract scores and reasons
|
||||
@@ -527,7 +527,7 @@ Your judgment must be more than a simple formula. You are to weigh the evidence
|
||||
]
|
||||
|
||||
final_evaluation_response = _call_model_generate(
|
||||
judge_client, final_message, max_tokens=130000, temperature=0
|
||||
judge_client, final_message, max_tokens=10000, temperature=0
|
||||
)
|
||||
final_score, final_reason = extract_evaluation_score(final_evaluation_response)
|
||||
|
||||
|
||||
@@ -96,6 +96,14 @@ class TestCaseGenerator(BasePipeline):
|
||||
attack_name, attack_config, output_image_dir=str(image_save_dir)
|
||||
)
|
||||
|
||||
if attack is None:
|
||||
self.logger.error(
|
||||
f"Attack method '{attack_name}' failed to initialize. "
|
||||
f"This is likely due to missing dependencies or import errors. "
|
||||
f"Check logs above for details. Skipping this attack."
|
||||
)
|
||||
return None
|
||||
|
||||
test_case = attack.generate_test_case(
|
||||
original_prompt,
|
||||
image_path,
|
||||
@@ -140,6 +148,14 @@ class TestCaseGenerator(BasePipeline):
|
||||
attack_name, attack_config, output_image_dir=str(image_save_dir)
|
||||
)
|
||||
|
||||
if attack is None:
|
||||
self.logger.error(
|
||||
f"Attack method '{attack_name}' failed to initialize. "
|
||||
f"This is likely due to missing dependencies or import errors. "
|
||||
f"Check logs above for details. Skipping this attack."
|
||||
)
|
||||
return []
|
||||
|
||||
# Unified resource policy (single source of truth)
|
||||
policy = policy_for_test_case_generation(
|
||||
attack_config, default_max_workers=self.config.max_workers
|
||||
@@ -359,6 +375,7 @@ class TestCaseGenerator(BasePipeline):
|
||||
attack_name,
|
||||
attack_config,
|
||||
task_id,
|
||||
image_save_dir,
|
||||
output_file_path,
|
||||
expected_count,
|
||||
)
|
||||
@@ -415,6 +432,7 @@ class TestCaseGenerator(BasePipeline):
|
||||
attack_name,
|
||||
attack_config,
|
||||
task_id,
|
||||
image_save_dir,
|
||||
output_file_path,
|
||||
expected_count,
|
||||
) in pending_attacks_to_process:
|
||||
|
||||
@@ -96,4 +96,4 @@ explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cu118" }
|
||||
torchvision = { index = "pytorch-cu118" }
|
||||
torchvision = { index = "pytorch-cu118" }
|
||||
|
||||
53
tests/test_generate_test_cases.py
Normal file
53
tests/test_generate_test_cases.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.data_formats import PipelineConfig
|
||||
from pipeline.generate_test_cases import TestCaseGenerator as CaseGenerator
|
||||
|
||||
|
||||
def test_run_uses_per_attack_image_dir(tmp_path):
    """Running the generator with two attacks must route each attack's
    images into its own per-attack directory (derived from the attack name)."""
    pipeline_config = PipelineConfig(
        system={"output_dir": str(tmp_path)},
        test_case_generation={
            "input": {"behaviors_file": "unused_in_test.json"},
            "attacks": ["jood", "hades"],
            "attack_params": {"jood": {}, "hades": {}},
        },
        batch_size=1,
    )
    generator = CaseGenerator(pipeline_config)

    stub_behaviors = [{"case_id": 1, "original_prompt": "p", "image_path": "img.jpg"}]
    # attack name -> image directory the generator actually passed through
    seen_image_dirs = {}

    def stub_filename(stage_name, **context):
        # Mirror the production naming scheme: one subtree per attack name.
        name = context["attack_name"]
        return tmp_path / name / "images", tmp_path / name / "test_cases.jsonl"

    def stub_single_attack(
        attack_name,
        attack_config,
        behaviors,
        batch_size,
        output_file_path,
        image_save_dir,
    ):
        # Record the directory handed to this attack instead of generating.
        seen_image_dirs[attack_name] = Path(image_save_dir)
        return []

    with patch.object(generator, "validate_config", return_value=True), patch.object(
        generator, "load_behaviors", return_value=stub_behaviors
    ), patch.object(
        generator, "_calculate_expected_test_cases", return_value=1
    ), patch.object(
        generator, "load_results", return_value=[]
    ), patch.object(
        generator, "_generate_filename", side_effect=stub_filename
    ), patch.object(
        generator,
        "generate_single_attack_test_cases",
        side_effect=stub_single_attack,
    ):
        generator.run()

    # Each attack got its own directory, and the two are distinct.
    assert seen_image_dirs["jood"] == tmp_path / "jood" / "images"
    assert seen_image_dirs["hades"] == tmp_path / "hades" / "images"
    assert seen_image_dirs["jood"] != seen_image_dirs["hades"]
|
||||
Reference in New Issue
Block a user