import torch
from tqdm import tqdm
import random
from minigpt_utils import prompt_wrapper, generator
from torchvision.utils import save_image
import numpy as np
from copy import deepcopy
import time

import matplotlib.pyplot as plt
from matplotlib.widgets import MultiCursor
import seaborn as sns


class Attacker:

    def __init__(self, args, model, targets, device="cuda:0"):

        self.args = args
        self.model = model
        self.device = device

        self.targets = targets  # target strings whose likelihood we want to promote
        self.loss_buffer = []
        self.num_targets = len(self.targets)

        # freeze the model and switch it to eval mode
        self.model.eval()
        self.model.requires_grad_(False)
        self.model.llama_tokenizer.padding_side = "right"

    def get_vocabulary(self):
        """Collect every vocabulary entry that re-tokenizes to a single token,
        together with its embedding, and build id <-> vocab lookup tables."""

        vocab_dicts = self.model.llama_tokenizer.get_vocab()
        vocabs = vocab_dicts.keys()

        single_token_vocabs = []
        single_token_vocabs_embedding = []
        single_token_id_to_vocab = dict()
        single_token_vocab_to_id = dict()

        cnt = 0

        for item in vocabs:
            tokens = self.model.llama_tokenizer(
                item, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)
            if tokens.shape[1] == 1:
                single_token_vocabs.append(item)
                emb = self.model.llama_model.model.embed_tokens(tokens)
                single_token_vocabs_embedding.append(emb)

                single_token_id_to_vocab[cnt] = item
                single_token_vocab_to_id[item] = cnt

                cnt += 1

        single_token_vocabs_embedding = torch.cat(
            single_token_vocabs_embedding, dim=1
        ).squeeze()

        self.vocabs = single_token_vocabs
        self.embedding_matrix = single_token_vocabs_embedding.to(self.device)
        self.id_to_vocab = single_token_id_to_vocab
        self.vocab_to_id = single_token_vocab_to_id

    def hotflip_attack(self, grad, token, increase_loss=False, num_candidates=1):
        """Rank replacement tokens for `token` with the first-order (HotFlip)
        approximation of the loss change; a toy version of this scoring is
        sketched at the bottom of this file."""

        token_id = self.vocab_to_id[token]
        token_emb = self.embedding_matrix[token_id]  # embedding of the current token

        # (e_i - e_cur) @ grad estimates the loss change of swapping in token i
        scores = ((self.embedding_matrix - token_emb) @ grad.T).squeeze(1)

        if not increase_loss:
            scores *= -1  # flip the sign so that the top-k candidates decrease the loss

        _, best_k_ids = torch.topk(scores, num_candidates)
        return best_k_ids.detach().cpu().numpy()

    def wrap_prompt(self, text_prompt_template, adv_prompt, queries, batch_size):

        text_prompts = text_prompt_template % (adv_prompt + " | " + queries)

        prompt = prompt_wrapper.Prompt(
            model=self.model, text_prompts=[text_prompts], img_prompts=[[]]
        )  # Add adv_img here

        prompt.context_embs[0] = prompt.context_embs[0].detach().requires_grad_(True)
        prompt.context_embs = prompt.context_embs * batch_size

        return prompt

    def wrap_prompt_simple(self, text_prompt_template, adv_prompt, batch_size):

        text_prompts = text_prompt_template % (adv_prompt)  # insert the adversarial prompt

        prompt = prompt_wrapper.Prompt(
            model=self.model, text_prompts=[text_prompts], img_prompts=[[]]
        )

        prompt.context_embs[0] = prompt.context_embs[0].detach().requires_grad_(True)
        prompt.context_embs = prompt.context_embs * batch_size

        return prompt

    def update_adv_prompt(self, adv_prompt_tokens, idx, new_token):
        next_adv_prompt_tokens = deepcopy(adv_prompt_tokens)
        next_adv_prompt_tokens[idx] = new_token
        next_adv_prompt = " ".join(next_adv_prompt_tokens)
        return next_adv_prompt_tokens, next_adv_prompt

    def attack(self, text_prompt_template, offset, batch_size=8, num_iter=2000):
        """Greedily optimize a textual trigger: for each position, use HotFlip
        to propose candidate tokens and keep the one with the lowest target loss."""

        print(">>> batch_size: ", batch_size)

        my_generator = generator.Generator(model=self.model, device=self.device)

        self.get_vocabulary()
        vocabs, embedding_matrix = self.vocabs, self.embedding_matrix

        trigger_token_length = 32  # number of tokens in the adversarial trigger
        adv_prompt_tokens = random.sample(vocabs, trigger_token_length)
        adv_prompt = " ".join(adv_prompt_tokens)

        st = time.time()

        for t in tqdm(range(num_iter + 1)):

            for token_to_flip in range(
                0, trigger_token_length
            ):  # for each token in the trigger

                batch_targets = random.sample(self.targets, batch_size)

                prompt = self.wrap_prompt_simple(
                    text_prompt_template, adv_prompt, batch_size
                )

                target_loss = self.attack_loss(prompt, batch_targets)
                loss = target_loss  # to minimize
                loss.backward()

                print("[adv_prompt]", adv_prompt)
                print("target_loss: %f" % (target_loss.item()))
                self.loss_buffer.append(target_loss.item())

                tokens_grad = prompt.context_embs[0].grad[:, token_to_flip + offset, :]
                candidates = self.hotflip_attack(
                    tokens_grad,
                    adv_prompt_tokens[token_to_flip],
                    increase_loss=False,
                    num_candidates=self.args.n_candidates,
                )

                self.model.zero_grad()

                # try all the candidates and pick the best;
                # comparing candidates does not require gradient computation
                with torch.no_grad():
                    curr_best_loss = 999999
                    curr_best_trigger_tokens = None
                    curr_best_trigger = None

                    for cand in candidates:
                        next_adv_prompt_tokens, next_adv_prompt = self.update_adv_prompt(
                            adv_prompt_tokens, token_to_flip, self.id_to_vocab[cand]
                        )
                        prompt = self.wrap_prompt_simple(
                            text_prompt_template, next_adv_prompt, batch_size
                        )

                        next_target_loss = self.attack_loss(prompt, batch_targets)
                        curr_loss = next_target_loss  # to minimize

                        if curr_loss < curr_best_loss:
                            curr_best_loss = curr_loss
                            curr_best_trigger_tokens = next_adv_prompt_tokens
                            curr_best_trigger = next_adv_prompt

                    # update the overall best if the best current candidate is better
                    if curr_best_loss < loss:
                        adv_prompt_tokens = curr_best_trigger_tokens
                        adv_prompt = curr_best_trigger
                        print("(update: %f minutes)" % ((time.time() - st) / 60))

                self.plot_loss()

            # log the current trigger and a sample generation after each full sweep
            print("######### Output - Iter = %d ##########" % t)
            prompt = self.wrap_prompt_simple(
                text_prompt_template, adv_prompt, batch_size
            )
            with torch.no_grad():
                response, _ = my_generator.generate(prompt)

            print("[prompt]", prompt.text_prompts[0])
            print(">>>", response)

        return adv_prompt

    def plot_loss(self):

        sns.set_theme()

        num_iters = len(self.loss_buffer)
        num_iters = min(num_iters, 5000)

        x_ticks = list(range(0, num_iters))

        # plot and label the target-loss values
        plt.plot(x_ticks, self.loss_buffer[:num_iters], label="Target Loss")

        # add a title and axis labels
        plt.title("Loss Plot")
        plt.xlabel("Iters")
        plt.ylabel("Loss")

        # save the plot
        plt.legend(loc="best")
        plt.savefig("%s/loss_curve.png" % (self.args.save_dir))
        plt.clf()

        torch.save(self.loss_buffer, "%s/loss" % (self.args.save_dir))

    def attack_loss(self, prompts, targets):
        """Teacher-forcing loss of the target strings given the (adversarial)
        prompt context; the label construction is illustrated by the standalone
        sketch at the bottom of this file."""

        context_embs = prompts.context_embs
        assert len(context_embs) == len(targets), (
            "Unmatched batch size of prompts and targets: len(context_embs) is %d, len(targets) is %d"
            % (len(context_embs), len(targets))
        )

        batch_size = len(targets)

        self.model.llama_tokenizer.padding_side = "right"

        to_regress_tokens = self.model.llama_tokenizer(
            targets,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.model.max_txt_len,
            add_special_tokens=False,
        ).to(self.device)
        to_regress_embs = self.model.llama_model.model.embed_tokens(
            to_regress_tokens.input_ids
        )

        bos = (
            torch.ones(
                [1, 1],
                dtype=to_regress_tokens.input_ids.dtype,
                device=to_regress_tokens.input_ids.device,
            )
            * self.model.llama_tokenizer.bos_token_id
        )
        bos_embs = self.model.llama_model.model.embed_tokens(bos)

        pad = (
            torch.ones(
                [1, 1],
                dtype=to_regress_tokens.input_ids.dtype,
                device=to_regress_tokens.input_ids.device,
            )
            * self.model.llama_tokenizer.pad_token_id
        )
        pad_embs = self.model.llama_model.model.embed_tokens(pad)

        # pad positions are set to -100 so they are ignored by the LM loss
        T = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.model.llama_tokenizer.pad_token_id, -100
        )

        pos_padding = torch.argmin(
            T, dim=1
        )  # a simple trick to find the start position of the padding

        input_embs = []
        targets_mask = []

        target_tokens_length = []
        context_tokens_length = []
        seq_tokens_length = []

        for i in range(batch_size):

            pos = int(pos_padding[i])
            if T[i][pos] == -100:
                target_length = pos
            else:
                target_length = T.shape[1]

            targets_mask.append(T[i : i + 1, :target_length])
            input_embs.append(
                to_regress_embs[i : i + 1, :target_length]
            )  # omit the padding tokens

            context_length = context_embs[i].shape[1]
            seq_length = target_length + context_length

            target_tokens_length.append(target_length)
            context_tokens_length.append(context_length)
            seq_tokens_length.append(seq_length)

        max_length = max(seq_tokens_length)

        attention_mask = []

        for i in range(batch_size):
            # mask out the context from the loss computation
            context_mask = (
                torch.ones([1, context_tokens_length[i] + 1], dtype=torch.long)
                .to(self.device)
                .fill_(-100)  # plus one for bos
            )

            # padding to align the lengths
            num_to_pad = max_length - seq_tokens_length[i]
            padding_mask = (
                torch.ones([1, num_to_pad], dtype=torch.long)
                .to(self.device)
                .fill_(-100)
            )

            targets_mask[i] = torch.cat(
                [context_mask, targets_mask[i], padding_mask], dim=1
            )
            input_embs[i] = torch.cat(
                [
                    bos_embs,
                    context_embs[i],
                    input_embs[i],
                    pad_embs.repeat(1, num_to_pad, 1),
                ],
                dim=1,
            )
            attention_mask.append(
                torch.LongTensor([[1] * (1 + seq_tokens_length[i]) + [0] * num_to_pad])
            )

        targets = torch.cat(targets_mask, dim=0).to(self.device)
        inputs_embs = torch.cat(input_embs, dim=0).to(self.device)
        attention_mask = torch.cat(attention_mask, dim=0).to(self.device)

        outputs = self.model.llama_model(
            inputs_embeds=inputs_embs,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        return loss
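

# --- Illustrative sketch (not part of the original attack code) -------------
# hotflip_attack() above ranks replacement tokens with a first-order Taylor
# approximation: swapping the current token embedding e_cur for a candidate
# e_i changes the loss by roughly (e_i - e_cur) @ grad, where grad is
# d(loss)/d(e_cur).  The toy function below replays that scoring on random
# tensors so the shapes and the top-k selection can be inspected in isolation;
# the vocabulary size and embedding dimension here are arbitrary stand-ins.
def _hotflip_scores_sketch(vocab_size=100, emb_dim=16, num_candidates=5):
    embedding_matrix = torch.randn(vocab_size, emb_dim)  # stand-in for self.embedding_matrix
    grad = torch.randn(1, emb_dim)  # stand-in for d(loss)/d(current token embedding)
    token_id = 0  # index of the current trigger token in the toy vocabulary

    token_emb = embedding_matrix[token_id]
    # estimated loss change for replacing token_id with every vocabulary entry
    scores = ((embedding_matrix - token_emb) @ grad.T).squeeze(1)
    # negate before top-k so the selected candidates *decrease* the loss,
    # mirroring increase_loss=False in hotflip_attack()
    _, best_k_ids = torch.topk(-scores, num_candidates)
    return best_k_ids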
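

# --- Illustrative sketch (not part of the original attack code) -------------
# attack_loss() builds the labels for the causal-LM loss by (1) replacing pad
# token ids with -100 (the ignore index of the Hugging Face LM loss), (2)
# locating the start of padding in each row with argmin, and (3) prepending
# -100 over the bos + context positions so only the target tokens contribute
# to the loss.  The toy below replays steps (1)-(3) on hand-made token ids;
# pad_token_id = 0 and context_length = 3 are made up for illustration.
def _label_masking_sketch():
    pad_token_id = 0
    input_ids = torch.tensor([[5, 7, 9, pad_token_id, pad_token_id],
                              [3, 4, 6, 8, 2]])

    # (1) pad positions become -100 so they are ignored by the loss
    T = input_ids.masked_fill(input_ids == pad_token_id, -100)

    # (2) argmin finds the first -100 in each row (the start of padding);
    #     rows without padding just return the position of the smallest token
    #     id, which is why attack_loss() re-checks T[i][pos] == -100
    pos_padding = torch.argmin(T, dim=1)
    target_lengths = [
        int(p) if T[i][int(p)] == -100 else T.shape[1]
        for i, p in enumerate(pos_padding)
    ]

    # (3) mask the bos + context positions out of the loss with more -100 labels
    context_length = 3
    labels = [
        torch.cat(
            [
                torch.full((1, context_length + 1), -100, dtype=torch.long),
                T[i : i + 1, : target_lengths[i]],
            ],
            dim=1,
        )
        for i in range(T.shape[0])
    ]
    return labels  # attack_loss() additionally right-pads these to a common length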