mirror of
https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git
synced 2026-05-08 20:34:52 +02:00
171 lines
5.1 KiB
Python
171 lines
5.1 KiB
Python
import torch
|
|
|
|
from ..attack import Attack
|
|
|
|
|
|
class MultiAttack(Attack):
    r"""
    MultiAttack is a class to attack a model with various attacks against the same images and labels.

    The attacks are applied in the given order. Each attack is only run on the
    samples that every earlier attack failed on (i.e. the model's prediction
    still equals the label); the adversarial image of the first attack that
    changes the prediction is kept for that sample.

    The model is not passed explicitly: it is taken from ``attacks[0].model``,
    and all attacks must reference that very same model object.

    Arguments:
        attacks (list): list of attacks (at least two, all sharing one model).
        verbose (bool): if True, print the cumulative success rate after every
            forward call. (Default: False)

    Examples::
        >>> atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, iters=40, random_start=True)
        >>> atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, iters=40, random_start=True)
        >>> atk = torchattacks.MultiAttack([atk1, atk2])
        >>> adv_images = atk(images, labels)

    """

    def __init__(self, attacks, verbose=False):
        # Fail with the documented ValueError instead of a bare IndexError
        # from ``attacks[0]`` when nothing was passed in.
        if len(attacks) == 0:
            raise ValueError("At least two attacks should be given.")

        # The base class must be initialised before any attribute is assigned.
        super().__init__("MultiAttack", attacks[0].model)
        self.attacks = attacks
        self.verbose = verbose
        self.supported_mode = ["default"]

        self.check_validity()

        # Book-keeping used by ``save`` to aggregate statistics over batches:
        # index 0 holds the number of processed samples, index k (k >= 1) the
        # number of samples still correctly classified after the k-th attack.
        self._accumulate_multi_atk_records = False
        self._multi_atk_records = [0.0]

    def check_validity(self):
        r"""
        Validate ``self.attacks``.

        Raises:
            ValueError: if fewer than two attacks were given, or if the
                attacks do not all reference the same model object.
        """
        if len(self.attacks) < 2:
            raise ValueError("At least two attacks should be given.")

        # Identity (not equality) is compared: every attack has to wrap the
        # very same model instance.
        model_ids = {id(attack.model) for attack in self.attacks}
        if len(model_ids) != 1:
            raise ValueError(
                "At least one of attacks is referencing a different model."
            )

    def forward(self, images, labels):
        r"""
        Overridden.

        Run the attacks one after another on the samples that are still
        classified as their label, and return a tensor in which every fooled
        sample is replaced by its adversarial image. Samples no attack could
        fool are returned unchanged.
        """
        # Move the inputs to the attack device *before* indexing them: the
        # index tensor ``fails`` lives on ``self.device`` and indexing a tensor
        # on another device with it is an error.
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        final_images = images.clone()

        batch_size = images.shape[0]
        # Indices (into the batch) of samples not yet fooled by any attack.
        fails = torch.arange(batch_size).to(self.device)

        # [batch size, remaining fails after attack 1, after attack 2, ...]
        multi_atk_records = [batch_size]

        for attack in self.attacks:
            adv_images = attack(images[fails], labels[fails])

            outputs = self.get_logits(adv_images)
            _, pre = torch.max(outputs.detach(), 1)

            corrects = pre == labels[fails]
            wrongs = ~corrects

            # ``wrongs`` is a mask over the current sub-batch; ``fails[wrongs]``
            # maps the fooled entries back to their positions in the full batch.
            final_images[fails[wrongs]] = adv_images[wrongs]

            fails = fails[corrects]
            multi_atk_records.append(len(fails))

            # Every sample has been fooled; later attacks have nothing to do.
            if len(fails) == 0:
                break

        if self.verbose:
            print(self._return_sr_record(multi_atk_records))

        if self._accumulate_multi_atk_records:
            self._update_multi_atk_records(multi_atk_records)

        return final_images

    def _clear_multi_atk_records(self):
        r"""Reset the accumulated records to a single zeroed sample counter."""
        self._multi_atk_records = [0.0]

    def _covert_to_success_rates(self, multi_atk_records):
        r"""
        Convert fail counts into cumulative success rates in percent.

        ``multi_atk_records[0]`` is the total number of samples and each later
        entry the number of samples still not fooled after that attack. When
        the total is zero the rates are reported as 0.0 rather than dividing
        by zero.
        """
        total = multi_atk_records[0]
        remaining = multi_atk_records[1:]
        if total == 0:
            return [0.0 for _ in remaining]
        return [(1 - left / total) * 100 for left in remaining]

    def _return_sr_record(self, multi_atk_records):
        r"""Format the success rates of ``multi_atk_records`` as one line of text."""
        sr = self._covert_to_success_rates(multi_atk_records)
        return "Attack success rate: " + " | ".join(["%2.2f %%" % item for item in sr])

    def _update_multi_atk_records(self, multi_atk_records):
        r"""
        Add the counts of one forward call to the accumulated records.

        ``multi_atk_records`` may be shorter than the accumulated list when
        ``forward`` stopped early; the untouched trailing entries then gain
        nothing, which matches zero remaining fails.
        """
        for i, item in enumerate(multi_atk_records):
            self._multi_atk_records[i] += item

    def save(
        self,
        data_loader,
        save_path=None,
        verbose=True,
        return_verbose=False,
        save_predictions=False,
        save_clean_images=False,
    ):
        r"""
        Overridden.

        Delegates to the base-class ``save`` while accumulating the per-attack
        records over all forward calls made during the run, so that the
        progress line and the returned value can report per-attack success
        rates. Per-call printing (``self.verbose``) is silenced for the
        duration of the run and restored afterwards, even if saving fails.

        Returns:
            ``(rob_acc, sr, l2, elapsed_time)`` when ``return_verbose`` is
            true, where ``sr`` is the list of cumulative success rates (in
            percent) per attack and the other three values are whatever the
            base-class ``save`` returned; otherwise ``None``.
        """
        self._clear_multi_atk_records()
        prev_verbose = self.verbose
        self.verbose = False
        self._accumulate_multi_atk_records = True

        # One accumulator slot per attack, after the leading sample counter.
        self._multi_atk_records.extend(0.0 for _ in self.attacks)

        try:
            result = super().save(
                data_loader,
                save_path,
                verbose,
                return_verbose,
                save_predictions,
                save_clean_images,
            )
            if return_verbose:
                rob_acc, l2, elapsed_time = result
                # Must be computed before the records are cleared below.
                sr = self._covert_to_success_rates(self._multi_atk_records)
                return rob_acc, sr, l2, elapsed_time
            return None
        finally:
            # Always restore the instance state, including on errors, so a
            # failed save does not leave the attack silenced or accumulating.
            self._clear_multi_atk_records()
            self._accumulate_multi_atk_records = False
            self.verbose = prev_verbose

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        r"""
        Overridden.

        Print one progress line that, in addition to the given progress,
        robust accuracy, L2 and speed values, shows the success rates derived
        from the records accumulated so far.
        """
        print(
            "- Save progress: %2.2f %% / Robust accuracy: %2.2f %%"
            % (progress, rob_acc)
            + " / "
            + self._return_sr_record(self._multi_atk_records)
            + " / L2: %1.5f (%2.3f it/s) \t" % (l2, elapsed_time),
            end=end,
        )