mirror of
https://github.com/jiaxiaojunQAQ/OmniSafeBench-MM.git
synced 2026-06-07 16:13:54 +02:00
183 lines
5.9 KiB
Python
183 lines
5.9 KiB
Python
import gc
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# from llm_attacks import get_embedding_matrix, get_embeddings
|
|
|
|
|
|
|
|
def token_gradients(model, adv_suffix_tokens, adv_len, offset, context_grad):
    """Project an embedding-space gradient onto one-hot token space for the adversarial suffix.

    Builds a differentiable one-hot matrix for ``adv_suffix_tokens`` and embeds it
    with the model's token-embedding matrix ``W``. It then differentiates the
    surrogate ``sum(embeds * g)``, where ``g`` is the detached slice
    ``context_grad[:, offset:offset + adv_len, :]``.

    By the chain rule, the gradient w.r.t. the one-hot rows is ``g @ W.T``. If
    ``g`` is the loss gradient w.r.t. the suffix embeddings (presumably what
    callers pass — TODO confirm), this is the loss gradient w.r.t. each token
    choice.

    Args:
        model: wrapper exposing ``model.llama_model.model.embed_tokens.weight``,
            used as a (vocab, dim) embedding matrix.
        adv_suffix_tokens: 1-D integer tensor holding the ``adv_len`` suffix token ids.
        adv_len: number of suffix tokens.
        offset: start of the suffix along the sequence axis of ``context_grad``.
        context_grad: rank-3 tensor indexed ``[:, seq, dim]``. Only the suffix
            window is read, and it is detached.

    Returns:
        An (adv_len, vocab) tensor. Each row is L2-normalised; all-zero rows stay zero.
    """
    embed_weights = model.llama_model.model.embed_tokens.weight
    device = embed_weights.device
    dtype = embed_weights.dtype

    # One-hot encoding of the current suffix; made a leaf so we can take d/d(one_hot).
    one_hot = torch.zeros(adv_len, embed_weights.shape[0], device=device, dtype=dtype)
    one_hot.scatter_(
        1,
        adv_suffix_tokens.to(device).unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=device, dtype=dtype),
    )
    one_hot.requires_grad_()
    input_embeds = (one_hot @ embed_weights).unsqueeze(0)

    embs_grad = context_grad[:, offset:offset + adv_len, :].detach().to(device)
    # Trick for calculating the gradient: d/d(one_hot) of sum(embeds * g) == g @ W.T.
    surrogate = (input_embeds * embs_grad).sum()
    # autograd.grad (unlike .backward()) only differentiates w.r.t. one_hot, so no
    # gradient is accumulated into the embedding matrix if it requires grad.
    (grad,) = torch.autograd.grad(surrogate, one_hot)

    # Guard against all-zero rows, which would otherwise become NaN.
    norm = grad.norm(dim=-1, keepdim=True)
    norm = norm.masked_fill(norm == 0, 1)
    return grad / norm
|
|
|
|
def sample_control(control_toks, grad, batch_size, topk=256, temp=1, not_allowed_tokens=None):
    """Sample ``batch_size`` single-token mutations of ``control_toks`` guided by ``grad``.

    For each row of ``grad`` (one row per control position), the ``topk`` tokens
    with the most negative gradient are candidates. Row ``i`` of the output
    copies ``control_toks`` and replaces the token at position
    ``i * len(control_toks) // batch_size`` with a uniformly random candidate
    for that position.

    Args:
        control_toks: 1-D tensor of current control token ids.
        grad: (len(control_toks), vocab) token-gradient matrix. It is not modified.
        batch_size: number of candidate sequences to produce.
        topk: number of best-gradient tokens sampled from per position.
        temp: unused; kept for interface compatibility.
        not_allowed_tokens: optional tensor or sequence of token ids that must
            never be sampled.

    Returns:
        A (batch_size, len(control_toks)) tensor of candidate control token ids.
    """
    if not_allowed_tokens is not None:
        # Mask on a copy so the caller's gradient tensor is left untouched;
        # +inf in grad becomes -inf in -grad, so topk never selects these ids.
        grad = grad.clone()
        grad[:, torch.as_tensor(not_allowed_tokens, device=grad.device)] = float("inf")

    top_indices = (-grad).topk(topk, dim=1).indices
    control_toks = control_toks.to(grad.device)
    original_control_toks = control_toks.repeat(batch_size, 1)

    # Spread the mutated positions evenly over the control sequence using exact
    # integer arithmetic; a float-step arange can round to batch_size + 1 entries.
    num_positions = len(control_toks)
    new_token_pos = (torch.arange(batch_size, device=grad.device) * num_positions) // batch_size

    new_token_val = torch.gather(
        top_indices[new_token_pos],
        1,
        torch.randint(0, topk, (batch_size, 1), device=grad.device),
    )
    return original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)
|
|
|
|
|
|
def get_filtered_cands(tokenizer, control_cand, filter_cand=True, curr_control=None):
    """Decode candidate control-token rows into strings, optionally filtering them.

    With ``filter_cand`` True, a candidate is kept only if both hold:

    * its decoded string differs from ``curr_control``;
    * it re-tokenizes (without special tokens) to as many tokens as the
      candidate row.

    The kept list is then padded back to ``len(control_cand)`` by repeating its
    last entry, so callers always receive one string per input row.

    Args:
        tokenizer: object with ``decode(ids, skip_special_tokens=...)`` that is also
            callable and returns an object with ``input_ids``.
        control_cand: iterable of candidate token-id rows (e.g. a 2-D tensor).
        filter_cand: whether to apply the filtering described above.
        curr_control: the current control string. Candidates equal to it are
            dropped, and it is used as the fallback if every candidate is dropped.

    Returns:
        A list of ``len(control_cand)`` strings.

    Raises:
        ValueError: if filtering removes every candidate and ``curr_control`` is None.
    """
    cands = []
    for cand_ids in control_cand:
        decoded_str = tokenizer.decode(cand_ids, skip_special_tokens=True)
        if not filter_cand:
            cands.append(decoded_str)
        elif decoded_str != curr_control and len(
            tokenizer(decoded_str, add_special_tokens=False).input_ids
        ) == len(cand_ids):
            cands.append(decoded_str)

    if filter_cand and len(cands) < len(control_cand):
        if not cands:
            if curr_control is None:
                raise ValueError(
                    "all candidates were filtered out and no curr_control was given"
                )
            # Every candidate failed the round-trip check: keep the current control
            # for this step instead of crashing on an empty list.
            cands = [curr_control]
        cands = cands + [cands[-1]] * (len(control_cand) - len(cands))
    return cands
|
|
|
|
|
|
def get_logits(*, model, tokenizer, input_ids, control_slice, test_controls=None, return_ids=False, batch_size=512):
    """Score candidate control strings by splicing each into ``input_ids`` and running ``model``.

    Each string in ``test_controls`` is:

    1. tokenized without special tokens;
    2. truncated to the width of ``control_slice``;
    3. right-padded to that width with a token id that appears nowhere else;
    4. written into a copy of ``input_ids`` at ``control_slice``.

    The batch is run through ``forward`` (defined elsewhere in this module), with
    an attention mask that zeroes only the padding.

    Args:
        model: causal LM; ``model.device`` selects where the batch is built.
        tokenizer: callable returning an object with a list ``input_ids``.
        input_ids: 1-D tensor of prompt token ids (assumed integer dtype — TODO confirm).
        control_slice: slice of ``input_ids`` occupied by the control tokens.
        test_controls: non-empty list of candidate control strings.
        return_ids: also return the spliced id batch when True.
        batch_size: rows per forward pass.

    Returns:
        The logits from ``forward``, or ``(logits, ids)`` when ``return_ids`` is True.

    Raises:
        ValueError: if ``test_controls`` is None, empty, or not a list of strings.
    """
    if test_controls is None or len(test_controls) == 0 or not isinstance(test_controls[0], str):
        raise ValueError(f"test_controls must be a list of strings, got {type(test_controls)}")

    device = model.device
    max_len = control_slice.stop - control_slice.start
    cand_ids = [
        tokenizer(control, add_special_tokens=False).input_ids[:max_len]
        for control in test_controls
    ]

    # Choose the smallest id absent from both the prompt and every candidate, so
    # the attention mask below masks out only the padding added here.
    used_ids = set(input_ids.flatten().tolist())
    for ids_ in cand_ids:
        used_ids.update(ids_)
    pad_tok = 0
    while pad_tok in used_ids:
        pad_tok += 1

    # Explicit right-padding (instead of the prototype torch.nested API) keeps
    # the dtype consistent with input_ids even when a candidate tokenizes to
    # zero tokens.
    test_ids = torch.full((len(cand_ids), max_len), pad_tok, dtype=input_ids.dtype, device=device)
    for row, ids_ in enumerate(cand_ids):
        if len(ids_) > 0:
            test_ids[row, :len(ids_)] = torch.tensor(ids_, dtype=input_ids.dtype, device=device)

    locs = torch.arange(control_slice.start, control_slice.stop, device=device).repeat(test_ids.shape[0], 1)
    ids = torch.scatter(
        input_ids.unsqueeze(0).repeat(test_ids.shape[0], 1).to(device),
        1,
        locs,
        test_ids,
    )
    attn_mask = (ids != pad_tok).type(ids.dtype)
    del locs, test_ids, cand_ids

    if return_ids:
        gc.collect()
        return forward(model=model, input_ids=ids, attention_mask=attn_mask, batch_size=batch_size), ids

    logits = forward(model=model, input_ids=ids, attention_mask=attn_mask, batch_size=batch_size)
    del ids
    gc.collect()
    return logits
|
|
|
|
|
|
def forward(*, model, input_ids, attention_mask, batch_size=512):
    """Run ``model`` over ``input_ids`` in row chunks of ``batch_size`` and stack the logits.

    Args:
        model: callable taking ``input_ids``/``attention_mask`` keyword arguments
            and returning an object with ``.logits``.
        input_ids: tensor whose first axis is the batch axis.
        attention_mask: tensor aligned with ``input_ids`` along the first axis, or None.
        batch_size: maximum number of rows per model call.

    Returns:
        The per-chunk logits concatenated along axis 0.
    """
    chunks = []
    total_rows = input_ids.shape[0]
    for start in range(0, total_rows, batch_size):
        stop = start + batch_size
        batch_input_ids = input_ids[start:stop]
        batch_attention_mask = None if attention_mask is None else attention_mask[start:stop]
        chunks.append(model(input_ids=batch_input_ids, attention_mask=batch_attention_mask).logits)
        # Collect eagerly between chunks to keep peak memory down.
        gc.collect()

    del batch_input_ids, batch_attention_mask
    return torch.cat(chunks, dim=0)
|
|
|
|
def target_loss(logits, ids, target_slice):
    """Per-sequence mean cross-entropy of the tokens ``ids[:, target_slice]``.

    The logit window is taken one position before the target window (logits at
    position t score token t + 1).

    Args:
        logits: rank-3 tensor indexed ``[batch, position, vocab]``.
        ids: rank-2 tensor of token ids indexed ``[batch, position]``.
        target_slice: slice of positions holding the target tokens.

    Returns:
        A tensor with one mean loss per batch row.
    """
    shifted = slice(target_slice.start - 1, target_slice.stop - 1)
    # cross_entropy wants the class axis second: (batch, vocab, positions).
    predictions = logits[:, shifted, :].transpose(1, 2)
    targets = ids[:, target_slice]
    per_token = nn.functional.cross_entropy(predictions, targets, reduction='none')
    return per_token.mean(dim=-1)
|
|
|
|
|
|
def load_model_and_tokenizer(model_path, tokenizer_path=None, device='cuda:0', **kwargs):
    """Load an fp16 causal LM onto ``device`` in eval mode, together with its slow tokenizer.

    The tokenizer is loaded from ``tokenizer_path``, or from ``model_path`` when
    that is None. Special-token and padding fixes are then applied based on
    substrings of the tokenizer path; several can match and all matching fixes
    apply. Finally, a missing pad token falls back to the EOS token.

    Args:
        model_path: path or hub id passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_path: optional separate tokenizer path or hub id.
        device: target device for the model.
        **kwargs: extra keyword arguments forwarded to the model loader.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, trust_remote_code=True, **kwargs
    )
    model = model.to(device)
    model.eval()

    if tokenizer_path is None:
        tokenizer_path = model_path

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=False)

    # Per-family fixups; deliberately independent `if`s, not an elif chain.
    if 'oasst-sft-6-llama-30b' in tokenizer_path:
        tokenizer.bos_token_id, tokenizer.unk_token_id = 1, 0
    if 'guanaco' in tokenizer_path:
        tokenizer.eos_token_id, tokenizer.unk_token_id = 2, 0
    if 'llama-2' in tokenizer_path:
        tokenizer.pad_token, tokenizer.padding_side = tokenizer.unk_token, 'left'
    if 'falcon' in tokenizer_path:
        tokenizer.padding_side = 'left'
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer