mirror of
https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git
synced 2026-02-13 02:02:45 +00:00
102 lines
3.0 KiB
Python
102 lines
3.0 KiB
Python
# attacks/umk/attack.py
|
|
from __future__ import annotations
|
|
from typing import Any, Dict, Iterable, Tuple, Optional, Union
|
|
|
|
import csv
|
|
import random
|
|
from torchvision.utils import save_image
|
|
from PIL import Image
|
|
|
|
from .minigpt_utils import vlm_attacker, prompt_wrapper
|
|
from dataclasses import dataclass
|
|
from core.base_classes import BaseAttack, TestCase
|
|
|
|
|
|
@dataclass
|
|
class UMKConfig:
|
|
n_iters: int = 5000
|
|
eps: int = 32
|
|
alpha: float = 1.0
|
|
corpus: str = "attacks/data/harmful_behaviors.csv"
|
|
prior_image: str = ""
|
|
batch_size: int = 8
|
|
target_model: str = None
|
|
|
|
|
|
class UMKAttack(BaseAttack):
|
|
CONFIG_CLASS = UMKConfig
|
|
|
|
def __init__(self, config: Dict[str, Any] = None, output_image_dir: str = None):
|
|
super().__init__(config, output_image_dir)
|
|
import importlib
|
|
|
|
self.target_model_name = str(getattr(self.cfg, "target_model", ""))
|
|
module_path, cls_name = self.target_model_name.rsplit(".", 1)
|
|
mod = importlib.import_module(f"multimodalmodels.{module_path}")
|
|
self.model = getattr(mod, cls_name)()
|
|
|
|
def generate_test_case(
|
|
self,
|
|
original_prompt: str,
|
|
image_path: str,
|
|
case_id: str,
|
|
**kwargs,
|
|
) -> TestCase:
|
|
|
|
images_dir = self.output_image_dir
|
|
|
|
file = open(self.cfg.corpus, "r")
|
|
|
|
data = list(csv.reader(file))
|
|
file.close()
|
|
queries = []
|
|
targets = []
|
|
num = len(data)
|
|
|
|
for i in range(1, num):
|
|
queries.append(data[i][0])
|
|
targets.append(data[i][1])
|
|
|
|
all_numbers = list(range(len(queries)))
|
|
|
|
selected_numbers = random.sample(all_numbers, 66)
|
|
|
|
queries = [queries[index] for index in selected_numbers]
|
|
targets = [targets[index] + ":" for index in selected_numbers]
|
|
|
|
my_attacker = vlm_attacker.Attacker(
|
|
self.cfg,
|
|
self.model,
|
|
queries,
|
|
targets,
|
|
device=self.model.device,
|
|
is_rtp=False,
|
|
)
|
|
|
|
img_path = getattr(self.cfg, "prior_image", "")
|
|
img = Image.open(img_path).convert("RGB")
|
|
img = self.model.vis_processor(img).unsqueeze(0).to(self.model.device)
|
|
text_prompt_template = prompt_wrapper.minigpt4_chatbot_prompt
|
|
adv_img_tensor, text_suffix = my_attacker.attack_unconstrained(
|
|
text_prompt_template,
|
|
img=img,
|
|
batch_size=self.cfg.batch_size,
|
|
num_iter=self.cfg.n_iters,
|
|
alpha=self.cfg.alpha / 255,
|
|
)
|
|
if adv_img_tensor.dim() == 3:
|
|
adv_t_save = adv_img_tensor.unsqueeze(0)
|
|
else:
|
|
adv_t_save = adv_img_tensor
|
|
save_path = images_dir / f"umk_{case_id}.jpg"
|
|
save_image(adv_t_save, str(save_path))
|
|
|
|
jailbreak_prompt = original_prompt + text_suffix
|
|
return self.create_test_case(
|
|
case_id=case_id,
|
|
jailbreak_prompt=jailbreak_prompt,
|
|
jailbreak_image_path=str(save_path),
|
|
original_prompt=original_prompt,
|
|
original_image_path=str(image_path),
|
|
)
|