mirror of
https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git
synced 2026-04-29 16:27:49 +02:00
280 lines
9.7 KiB
Python
280 lines
9.7 KiB
Python
import torch
|
|
from tqdm import tqdm
|
|
import random
|
|
from minigpt_utils import prompt_wrapper, generator
|
|
from torchvision.utils import save_image
|
|
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
|
|
|
|
def normalize(images):
    """Standardise an image batch channel-wise: ``(images - mean) / std``.

    The per-channel statistics are three-element vectors that are broadcast
    as shape ``(1, 3, 1, 1)``, so ``images`` must be a 4-D tensor whose
    dimension 1 has size 3 (or 1).
    NOTE(review): the constants look like the CLIP preprocessing statistics
    -- confirm against the vision encoder actually in use.

    Args:
        images: Tensor to standardise. It is not modified in place.

    Returns:
        A new tensor on the same device as ``images``.
    """
    # Build the statistics on the input's own device instead of forcing the
    # default CUDA device, so CPU tensors and non-default GPUs work as well.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images - mean[None, :, None, None]
    images = images / std[None, :, None, None]
    return images
|
|
|
|
|
|
def denormalize(images):
    """Undo channel-wise standardisation: ``images * std + mean``.

    Uses the same three-element mean/std vectors as ``normalize`` and
    broadcasts them as shape ``(1, 3, 1, 1)``, so ``images`` must be a 4-D
    tensor whose dimension 1 has size 3 (or 1).

    Args:
        images: Standardised tensor. It is not modified in place.

    Returns:
        A new tensor on the same device as ``images``.
    """
    # Build the statistics on the input's own device instead of forcing the
    # default CUDA device, so CPU tensors and non-default GPUs work as well.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images * std[None, :, None, None]
    images = images + mean[None, :, None, None]
    return images
|
|
|
|
|
|
class Attacker:
    """Optimise an adversarial image prompt against a frozen model.

    Both attack methods run signed-gradient descent on an image tensor so
    that the loss returned by ``attack_loss`` (the language-model loss of
    sampled target strings given the image + text context) decreases.

    Attributes set by ``__init__``:
        args: Run configuration; only ``args.save_dir`` is read in this class.
        model: The model under attack; switched to eval mode and frozen.
        device: Device used for the adversarial tensor and the loss inputs.
        is_rtp: Stored as given; not read anywhere in this class.
        targets: Sequence of target strings sampled from at every step.
        num_targets: ``len(targets)``.
        loss_buffer: List of per-iteration loss values (Python floats).
    """

    def __init__(self, args, model, targets, device='cuda:0', is_rtp=False):
        self.args = args
        self.model = model
        self.device = device
        self.is_rtp = is_rtp

        self.targets = targets
        self.num_targets = len(targets)

        self.loss_buffer = []

        # freeze and set to eval model:
        self.model.eval()
        self.model.requires_grad_(False)

    def _build_prompt(self, text_prompt, x_adv, batch_size):
        """Wrap one adversarial image plus ``batch_size`` copies of the text in a Prompt."""
        text_prompts = [text_prompt] * batch_size
        prompt = prompt_wrapper.Prompt(model=self.model, text_prompts=text_prompts, img_prompts=[[x_adv]])
        # presumably img_embs is a list, so `*` repeats the single image
        # embedding once per batch element -- verify against prompt_wrapper.
        prompt.img_embs = prompt.img_embs * batch_size
        prompt.update_context_embs()
        return prompt

    def _log_loss(self, target_loss, t):
        """Record and print the current loss; refresh the loss plot every 20 iterations."""
        self.loss_buffer.append(target_loss.item())

        print("target_loss: %f" % (
            target_loss.item())
        )

        if t % 20 == 0:
            self.plot_loss()

    @staticmethod
    def _to_image(x_adv):
        """Turn a normalised adversarial batch into a detached CPU image tensor."""
        adv_img_prompt = denormalize(x_adv).detach().cpu()
        return adv_img_prompt.squeeze(0)

    def _checkpoint(self, my_generator, prompt, x_adv, batch_size, t):
        """Print a sample generation for the current image and save it to ``save_dir``."""
        print('######### Output - Iter = %d ##########' % t)
        prompt.update_img_prompts([[x_adv]])
        prompt.img_embs = prompt.img_embs * batch_size
        prompt.update_context_embs()
        with torch.no_grad():
            response, _ = my_generator.generate(prompt)
        print('>>>', response)

        save_image(self._to_image(x_adv), '%s/bad_prompt_temp_%d.bmp' % (self.args.save_dir, t))

    def attack_unconstrained(self, text_prompt, img, batch_size=8, num_iter=2000, alpha=1 / 255):
        """Optimise a free image (every pixel anywhere in [0, 1]).

        The image is initialised with uniform noise shaped like ``img``; the
        content of ``img`` itself is not used.

        Args:
            text_prompt: Text prompt paired with the image at every step.
            img: Tensor whose shape/dtype define the adversarial image.
            batch_size: Number of targets sampled per step; ``random.sample``
                raises ``ValueError`` if it exceeds ``len(self.targets)``.
            num_iter: Last iteration index; the loop runs ``num_iter + 1`` steps.
            alpha: Step size of each signed-gradient update.

        Returns:
            The adversarial image after the final update, as a detached CPU
            tensor in pixel space (batch dimension squeezed if it is 1).
        """
        print('>>> batch_size:', batch_size)

        my_generator = generator.Generator(model=self.model)

        adv_noise = torch.rand_like(img).to(self.device)  # [0,1]
        adv_noise.requires_grad_(True)
        adv_noise.retain_grad()

        for t in tqdm(range(num_iter + 1)):
            batch_targets = random.sample(self.targets, batch_size)

            x_adv = normalize(adv_noise)
            prompt = self._build_prompt(text_prompt, x_adv, batch_size)

            target_loss = self.attack_loss(prompt, batch_targets)
            target_loss.backward()

            # Signed-gradient descent step, then project back to valid pixels.
            adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(0, 1)
            adv_noise.grad.zero_()
            self.model.zero_grad()

            self._log_loss(target_loss, t)

            if t % 100 == 0:
                self._checkpoint(my_generator, prompt, normalize(adv_noise), batch_size, t)

        # Build the result from the final state rather than from the last
        # checkpoint, so it is up to date even when num_iter % 100 != 0.
        adv_img_prompt = self._to_image(normalize(adv_noise))

        return adv_img_prompt

    def attack_constrained(self, text_prompt, img, batch_size=8, num_iter=2000, alpha=1 / 255, epsilon=128 / 255):
        """Optimise a perturbation of ``img`` bounded by ``epsilon`` per pixel.

        ``img`` is passed through ``denormalize`` first, i.e. it is expected
        in normalised space. The perturbation is kept within
        ``[-epsilon, epsilon]`` and such that the perturbed image stays in
        ``[0, 1]``.

        Args:
            text_prompt: Text prompt paired with the image at every step.
            img: Normalised clean image tensor.
            batch_size: Number of targets sampled per step; ``random.sample``
                raises ``ValueError`` if it exceeds ``len(self.targets)``.
            num_iter: Last iteration index; the loop runs ``num_iter + 1`` steps.
            alpha: Step size of each signed-gradient update.
            epsilon: Per-pixel bound on the perturbation magnitude.

        Returns:
            ``(adv_img_prompt, target_loss)``: the adversarial image after the
            final update (detached CPU tensor in pixel space) and the loss
            tensor of the last iteration.
        """
        print('>>> batch_size:', batch_size)

        my_generator = generator.Generator(model=self.model)

        adv_noise = torch.rand_like(img).to(self.device) * 2 * epsilon - epsilon
        x = denormalize(img).clone().to(self.device)
        # Shrink the initial noise so that x + noise is a valid image.
        adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data

        adv_noise.requires_grad_(True)
        adv_noise.retain_grad()

        for t in tqdm(range(num_iter + 1)):
            batch_targets = random.sample(self.targets, batch_size)

            x_adv = normalize(x + adv_noise)
            prompt = self._build_prompt(text_prompt, x_adv, batch_size)

            target_loss = self.attack_loss(prompt, batch_targets)
            target_loss.backward()

            # Signed-gradient step, projected onto the epsilon ball and then
            # onto the set of perturbations that keep the image in [0, 1].
            adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(-epsilon, epsilon)
            adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data
            adv_noise.grad.zero_()
            self.model.zero_grad()

            self._log_loss(target_loss, t)

            if t % 100 == 0:
                self._checkpoint(my_generator, prompt, normalize(x + adv_noise), batch_size, t)

        # Build the result from the final state rather than from the last
        # checkpoint, so it matches the returned loss's iteration even when
        # num_iter % 100 != 0.
        adv_img_prompt = self._to_image(normalize(x + adv_noise))

        return adv_img_prompt, target_loss

    def plot_loss(self):
        """Save the loss curve as ``loss_curve.png`` and the raw values as ``loss``.

        Both files are written into ``self.args.save_dir``; the current
        matplotlib figure is cleared afterwards.
        """
        sns.set_theme()
        num_iters = len(self.loss_buffer)

        x_ticks = list(range(0, num_iters))

        # Plot and label the loss values
        plt.plot(x_ticks, self.loss_buffer, label='Target Loss')

        # Add in a title and axes labels
        plt.title('Loss Plot')
        plt.xlabel('Iters')
        plt.ylabel('Loss')

        # Add the legend and write the figure to disk
        plt.legend(loc='best')
        plt.savefig('%s/loss_curve.png' % (self.args.save_dir))
        plt.clf()

        torch.save(self.loss_buffer, '%s/loss' % (self.args.save_dir))

    def attack_loss(self, prompts, targets):
        """Compute the language-model loss of ``targets`` given ``prompts``.

        For every sample the input sequence is
        ``[bos, context, target tokens, padding]``; labels are ``-100`` for
        the bos, context and padding positions and the target token ids
        elsewhere.

        Side effect: sets ``self.model.llama_tokenizer.padding_side`` to
        ``"right"``.

        Args:
            prompts: Object exposing ``context_embs``, a sequence of context
                embeddings; a single entry is repeated for every target.
            targets: List of target strings.

        Returns:
            ``outputs.loss`` of ``self.model.llama_model`` for the batch.

        Raises:
            AssertionError: If the number of context embeddings and targets
                differ (after the single-context expansion).
        """
        context_embs = prompts.context_embs

        if len(context_embs) == 1:
            context_embs = context_embs * len(targets)  # expand to fit the batch_size

        assert len(context_embs) == len(
            targets), f"Unmathced batch size of prompts and targets {len(context_embs)} != {len(targets)}"

        batch_size = len(targets)
        self.model.llama_tokenizer.padding_side = "right"

        to_regress_tokens = self.model.llama_tokenizer(
            targets,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.model.max_txt_len,
            add_special_tokens=False
        ).to(self.device)
        to_regress_embs = self.model.llama_model.model.embed_tokens(to_regress_tokens.input_ids)

        bos = torch.ones([1, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.model.llama_tokenizer.bos_token_id
        bos_embs = self.model.llama_model.model.embed_tokens(bos)

        pad = torch.ones([1, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.model.llama_tokenizer.pad_token_id
        pad_embs = self.model.llama_model.model.embed_tokens(pad)

        # Label tensor: target token ids with padding positions set to -100.
        T = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.model.llama_tokenizer.pad_token_id, -100
        )

        pos_padding = torch.argmin(T, dim=1)  # a simple trick to find the start position of padding

        input_embs = []
        targets_mask = []

        context_tokens_length = []
        seq_tokens_length = []

        for i in range(batch_size):

            pos = int(pos_padding[i])
            if T[i][pos] == -100:
                target_length = pos
            else:
                # No padding in this row: the target fills the whole width.
                target_length = T.shape[1]

            targets_mask.append(T[i:i + 1, :target_length])
            input_embs.append(to_regress_embs[i:i + 1, :target_length])  # omit the padding tokens

            context_length = context_embs[i].shape[1]
            seq_length = target_length + context_length

            context_tokens_length.append(context_length)
            seq_tokens_length.append(seq_length)

        max_length = max(seq_tokens_length)

        attention_mask = []

        for i in range(batch_size):
            # mask out the context from loss computation (plus one for bos)
            context_mask = torch.full([1, context_tokens_length[i] + 1], -100,
                                      dtype=torch.long, device=self.device)

            # padding to align the length
            num_to_pad = max_length - seq_tokens_length[i]
            padding_mask = torch.full([1, num_to_pad], -100,
                                      dtype=torch.long, device=self.device)

            targets_mask[i] = torch.cat([context_mask, targets_mask[i], padding_mask], dim=1)
            input_embs[i] = torch.cat([bos_embs, context_embs[i], input_embs[i],
                                       pad_embs.repeat(1, num_to_pad, 1)], dim=1)
            attention_mask.append(torch.LongTensor([[1] * (1 + seq_tokens_length[i]) + [0] * num_to_pad]))

        targets = torch.cat(targets_mask, dim=0).to(self.device)
        inputs_embs = torch.cat(input_embs, dim=0).to(self.device)
        attention_mask = torch.cat(attention_mask, dim=0).to(self.device)

        outputs = self.model.llama_model(
            inputs_embeds=inputs_embs,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        return loss
|