UDora/udora/attack.py
javyduck 0d259fc3aa init_v1
2025-06-24 01:50:12 +00:00

"""
UDora: A Unified Red Teaming Framework against LLM Agents by Dynamically Hijacking Their Own Reasoning
This module implements the core UDora attack algorithm that dynamically optimizes adversarial strings
by leveraging LLM agents' own reasoning processes to trigger targeted malicious actions.
Key features:
- Dynamic position identification for noise insertion in reasoning traces
- Weighted interval scheduling for optimal target placement
- Support for both sequential and joint optimization modes
- Fallback mechanisms for robust attack execution
"""
import os
import copy
import gc
import logging
import re
from dataclasses import dataclass
from tqdm import tqdm
from typing import List, Optional, Union
import math
import torch
import transformers
from torch import Tensor
from transformers import set_seed
from .utils import INIT_CHARS, find_executable_batch_size, get_nonascii_toks, mellowmax, check_type, combine_with_overlap
from .scheduling import weighted_interval_scheduling, filter_intervals_by_sequential_mode, build_final_token_sequence
from .datasets import check_success_condition
from .text_processing import build_target_intervals, count_matched_locations, format_interval_debug_info
from .loss import compute_udora_loss, compute_cross_entropy_loss, apply_mellowmax_loss
from .readable import create_readable_optimizer
logger = logging.getLogger("UDora")
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

@dataclass
class UDoraConfig:
    """Configuration class for UDora attack parameters.

    Core Attack Parameters:
        num_steps: Maximum number of optimization iterations
        optim_str_init: Initial adversarial string(s) to optimize
        search_width: Number of candidate sequences to evaluate per iteration
        batch_size: Batch size for processing candidates (None for auto-sizing)
        topk: Top-k gradient directions to consider for token replacement
        n_replace: Number of token positions to update per sequence

    Advanced Optimization:
        buffer_size: Size of attack buffer for maintaining best candidates
        use_mellowmax: Whether to use mellowmax instead of cross-entropy loss
        mellowmax_alpha: Alpha parameter for the mellowmax function
        early_stop: Stop optimization as soon as the target action is triggered

    Model Interface:
        use_prefix_cache: Cache prefix computations for efficiency
        allow_non_ascii: Allow non-ASCII tokens in optimization
        filter_ids: Filter out token sequences that change after retokenization
        add_space_before_target: Add a space before target text during matching
        max_new_tokens: Maximum tokens to generate during inference

    Experimental Features:
        minimize_reward: Minimize instead of maximize target probability (inverse attack)
        sequential: Use sequential rather than joint optimization mode
        weight: Exponential weighting factor for multi-target optimization (weight^position)
        num_location: Number of target insertion locations
        prefix_update_frequency: How often (in steps) to refresh the reasoning context
        readable: Optimize for readable adversarial strings using perplexity guidance
        before_negative: Stop interval collection when encountering negative responses such as "cannot"

    Utility:
        seed: Random seed for reproducibility
        verbosity: Logging level ("DEBUG", "INFO", "WARNING", "ERROR")
        dataset: Dataset name for success conditions. Predefined: "webshop", "injecagent", "agentharm".
            Custom dataset names are also supported and use the default text-matching behavior.
    """

    # Core attack parameters
    num_steps: int = 250
    optim_str_init: Union[str, List[str]] = "x x x x x x x x x x x x x x x x x x x x x x x x x"
    search_width: int = 512
    batch_size: Optional[int] = None
    topk: int = 256
    n_replace: int = 1

    # Advanced optimization
    buffer_size: int = 0
    use_mellowmax: bool = False
    mellowmax_alpha: float = 1.0
    early_stop: bool = False

    # Model interface
    use_prefix_cache: bool = True
    allow_non_ascii: bool = False
    filter_ids: bool = True
    add_space_before_target: bool = False
    max_new_tokens: int = 300

    # Experimental features (beta)
    minimize_reward: bool = False  # beta: minimize the appearance of specific target noise
    sequential: bool = True
    weight: float = 1.0  # exponential weighting factor for multi-target loss (weight^position)
    num_location: int = 2
    prefix_update_frequency: int = 1
    readable: bool = False  # beta: optimize for readable adversarial strings via perplexity (requires a readable optim_str_init)
    before_negative: bool = False  # beta: stop interval collection when encountering negative responses

    # Utility
    seed: Optional[int] = None
    verbosity: str = "INFO"
    dataset: str = "webshop"  # dataset name selecting the success condition: "webshop", "injecagent", "agentharm"
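
# Example UDoraConfig (an illustrative sketch; these particular values are
# assumptions for demonstration, not tuned defaults from the authors):
#
#     config = UDoraConfig(
#         num_steps=100,          # fewer iterations for a quick run
#         search_width=256,
#         num_location=3,         # try three insertion points in the reasoning trace
#         early_stop=True,        # stop as soon as the target action fires
#         dataset="webshop",
#     )
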
@dataclass
class UDoraResult:
    """Results from a UDora attack execution.

    Attributes:
        best_loss: Lowest loss achieved during optimization
        best_string: Adversarial string that achieved the best loss
        best_generation: Model generation with the best adversarial string
        best_success: Whether the best attempt successfully triggered the target action
        last_string: Final adversarial string from optimization
        last_generation: Model generation with the final adversarial string
        last_success: Whether the final attempt successfully triggered the target action
        vanilla_generation: Model generation without any adversarial string
        vanilla_success: Whether the vanilla generation triggered the target action
        all_generation: Complete history of generations during optimization
        losses: Loss values throughout the optimization process
        strings: Adversarial strings throughout the optimization process
    """

    # Best results
    best_loss: float
    best_string: str
    best_generation: Union[str, List[str]]
    best_success: Union[bool, List[bool]]

    # Final results
    last_string: str
    last_generation: Union[str, List[str]]
    last_success: Union[bool, List[bool]]

    # Baseline results
    vanilla_generation: Union[str, List[str]]
    vanilla_success: Union[bool, List[bool]]

    # Optimization history
    all_generation: List[list]
    losses: List[float]
    strings: List[str]

class AttackBuffer:
    """Buffer for maintaining the best adversarial candidates during optimization.

    Keeps track of the top candidates based on loss values, enabling efficient
    selection of promising adversarial strings for continued optimization.
    """

    def __init__(self, size: int):
        """Initialize attack buffer.

        Args:
            size: Maximum number of candidates to maintain (0 for a single candidate)
        """
        self.buffer = []  # Elements are (loss: float, optim_ids: Tensor)
        self.size = size

    def add(self, loss: float, optim_ids: Tensor) -> None:
        """Add a new candidate to the buffer.

        Args:
            loss: Loss value for the candidate
            optim_ids: Token IDs of the adversarial string
        """
        if self.size == 0:
            self.buffer = [(loss, optim_ids)]
            return

        if len(self.buffer) < self.size:
            self.buffer.append((loss, optim_ids))
        else:
            # Only add if the new loss beats the worst in the buffer
            worst_loss = max(item[0] for item in self.buffer)
            if loss < worst_loss:
                # Replace the worst item with the new candidate
                worst_idx = max(range(len(self.buffer)), key=lambda i: self.buffer[i][0])
                self.buffer[worst_idx] = (loss, optim_ids)
            else:
                # New candidate is no better than any existing one; don't add
                return

        # Keep buffer sorted by loss (best first)
        self.buffer.sort(key=lambda x: x[0])

    def get_best_ids(self) -> Tensor:
        """Get token IDs of the best candidate."""
        if not self.buffer:
            raise RuntimeError("Cannot get best IDs from empty attack buffer")
        return self.buffer[0][1]

    def get_lowest_loss(self) -> float:
        """Get the lowest (best) loss value."""
        if not self.buffer:
            raise RuntimeError("Cannot get lowest loss from empty attack buffer")
        return self.buffer[0][0]

    def get_highest_loss(self) -> float:
        """Get the highest (worst) loss value."""
        if not self.buffer:
            raise RuntimeError("Cannot get highest loss from empty attack buffer")
        return self.buffer[-1][0]

    def log_buffer(self, tokenizer):
        """Log current buffer contents for debugging."""
        message = "Attack buffer contents:"
        for loss, ids in self.buffer:
            optim_str = tokenizer.batch_decode(ids)[0]
            optim_str = optim_str.replace("\\", "\\\\")
            optim_str = optim_str.replace("\n", "\\n")
            message += f"\n  loss: {loss:.4f} | string: {optim_str}"
        logger.info(message)

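
# AttackBuffer behavior sketch (ids_a/ids_b/ids_c stand in for (1, n_optim_ids)
# LongTensors; the loss values are hypothetical):
#
#     buf = AttackBuffer(size=2)
#     buf.add(1.5, ids_a)
#     buf.add(0.9, ids_b)        # buffer now holds [(0.9, ids_b), (1.5, ids_a)]
#     buf.add(2.0, ids_c)        # rejected: 2.0 is worse than the current worst (1.5)
#     buf.get_lowest_loss()      # -> 0.9
#     buf.get_best_ids()         # -> ids_b
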
def sample_ids_from_grad(
    ids: Tensor,
    grad: Tensor,
    search_width: int,
    topk: int = 256,
    n_replace: int = 1,
    not_allowed_ids: Optional[Tensor] = None,
):
    """Returns `search_width` combinations of token ids based on the token gradient.

    Args:
        ids : Tensor, shape = (n_optim_ids)
            the sequence of token ids that are being optimized
        grad : Tensor, shape = (n_optim_ids, vocab_size)
            the gradient of the GCG loss computed with respect to the one-hot token embeddings
        search_width : int
            the number of candidate sequences to return
        topk : int
            the topk to be used when sampling from the gradient
        n_replace : int
            the number of token positions to update per sequence
        not_allowed_ids : Tensor, shape = (n_ids)
            the token ids that should not be used in optimization

    Returns:
        sampled_ids : Tensor, shape = (search_width, n_optim_ids)
            sampled token ids
    """
    n_optim_tokens = len(ids)
    original_ids = ids.repeat(search_width, 1)

    if not_allowed_ids is not None:
        grad[:, not_allowed_ids.to(grad.device)] = float("inf")

    topk_ids = (-grad).topk(topk, dim=1).indices

    sampled_ids_pos = torch.argsort(torch.rand((search_width, n_optim_tokens), device=grad.device))[..., :n_replace]
    sampled_ids_val = torch.gather(
        topk_ids[sampled_ids_pos],
        2,
        torch.randint(0, topk, (search_width, n_replace, 1), device=grad.device)
    ).squeeze(2)

    new_ids = original_ids.scatter_(1, sampled_ids_pos, sampled_ids_val)
    return new_ids

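
# Shape sketch for sample_ids_from_grad (dummy vocabulary size; not tied to a
# particular model):
#
#     ids = torch.randint(0, 32000, (20,))   # 20 tokens under optimization
#     grad = torch.randn(20, 32000)          # d(loss)/d(one-hot embeddings)
#     cands = sample_ids_from_grad(ids, grad, search_width=512, topk=256, n_replace=1)
#     cands.shape                            # torch.Size([512, 20]); each row differs
#                                            # from `ids` in n_replace=1 position
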
def sample_ids_from_grad_readable(
    ids: Tensor,
    grad: Tensor,
    search_width: int,
    readable_optimizer=None,
    readable_strength: float = 0.3,
    topk: int = 256,
    n_replace: int = 1,
    not_allowed_ids: Optional[Tensor] = None,
):
    """Sample candidate token sequences with readable guidance.

    This function extends the standard gradient sampling with readable guidance
    to encourage natural language patterns in adversarial strings using perplexity.

    Args:
        ids: Current token sequence being optimized, shape (n_optim_ids,)
        grad: Gradient w.r.t. one-hot token embeddings, shape (n_optim_ids, vocab_size)
        search_width: Number of candidate sequences to generate
        readable_optimizer: ReadableOptimizer instance for guidance
        readable_strength: Strength of readable guidance (0.0 to 1.0)
        topk: Number of top gradient directions to consider per position
        n_replace: Number of token positions to modify per candidate
        not_allowed_ids: Token IDs to exclude from optimization (e.g., non-ASCII)

    Returns:
        Tensor of sampled token sequences, shape (search_width, n_optim_ids)
    """
    n_optim_tokens = len(ids)
    original_ids = ids.repeat(search_width, 1)

    if not_allowed_ids is not None:
        grad[:, not_allowed_ids.to(grad.device)] = float("inf")

    # Apply readable guidance to the gradients if enabled
    if readable_optimizer is not None and readable_strength > 0.0:
        for pos in range(n_optim_tokens):
            grad[pos] = readable_optimizer.apply_readable_guidance(
                grad[pos], ids, pos, readable_strength
            )

    topk_ids = (-grad).topk(topk, dim=1).indices

    sampled_ids_pos = torch.argsort(torch.rand((search_width, n_optim_tokens), device=grad.device))[..., :n_replace]
    sampled_ids_val = torch.gather(
        topk_ids[sampled_ids_pos],
        2,
        torch.randint(0, topk, (search_width, n_replace, 1), device=grad.device)
    ).squeeze(2)

    new_ids = original_ids.scatter_(1, sampled_ids_pos, sampled_ids_val)
    return new_ids

def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer):
    """Filter out token sequences that change after a decode-encode round trip.

    This ensures that adversarial strings remain consistent when processed
    through the tokenizer's decode/encode cycle, which is important for
    maintaining attack effectiveness.

    Args:
        ids: Token ID sequences to filter, shape (search_width, n_optim_ids)
        tokenizer: Model's tokenizer for decode/encode operations

    Returns:
        Filtered token sequences that survive the round trip, shape (new_search_width, n_optim_ids)

    Raises:
        RuntimeError: If no sequences survive filtering (suggests a bad initialization)
    """
    ids_decoded = tokenizer.batch_decode(ids)
    filtered_ids = []

    for i in range(len(ids_decoded)):
        # Re-encode the decoded string and compare against the original ids
        ids_encoded = tokenizer(ids_decoded[i], return_tensors="pt", add_special_tokens=False).to(ids.device)["input_ids"][0]
        if torch.equal(ids[i], ids_encoded):
            filtered_ids.append(ids[i])

    if not filtered_ids:
        # This occurs in some cases, e.g. using the Llama-3 tokenizer with a bad initialization
        raise RuntimeError(
            "No token sequences are the same after decoding and re-encoding. "
            "Consider setting `filter_ids=False` or trying a different `optim_str_init`"
        )

    return torch.stack(filtered_ids)

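
# Round-trip intuition for filter_ids: a candidate whose pieces merge on
# re-encoding is dropped. Sketch (the token strings are illustrative):
#
#     ["en", "code"]  --decode-->  "encode"  --encode-->  ["encode"]
#
# Such a candidate would behave differently once its decoded string is fed back
# to the model, so it is filtered out by default (config.filter_ids=True).
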
class UDora:
    """UDora: Unified Red Teaming Framework for LLM Agents.

    This class implements the core UDora attack algorithm that dynamically hijacks
    LLM agents' reasoning processes to trigger targeted malicious actions. The attack
    works by:

    1. Gathering the agent's initial reasoning response
    2. Identifying optimal positions for inserting target "noise"
    3. Optimizing adversarial strings to maximize noise likelihood
    4. Iteratively refining the attack based on the agent's evolving reasoning

    The approach adapts to different agent reasoning styles and can handle both
    malicious environment and malicious instruction scenarios.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        config: UDoraConfig,
    ):
        """Initialize UDora attack framework.

        Args:
            model: Target LLM to attack
            tokenizer: Tokenizer for the target model
            config: Attack configuration parameters
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        # Core components
        self.embedding_layer = model.get_input_embeddings()
        self.not_allowed_ids = None if config.allow_non_ascii else get_nonascii_toks(tokenizer, device=model.device)
        self.batch_prefix_cache = None

        # Initialize readable optimizer if enabled
        self.readable_optimizer = None
        if config.readable:
            self.readable_optimizer = create_readable_optimizer(tokenizer, model)
            logger.info("Readable optimization enabled - adversarial strings will be optimized for natural language patterns")

        # Attack state
        self.stop_flag = False

        # Performance warnings
        if model.dtype in (torch.float32, torch.float64):
            logger.warning(f"Model is in {model.dtype}. Use a lower precision data type, if possible, for much faster optimization.")
        if model.device == torch.device("cpu"):
            logger.warning("Model is on the CPU. Use a hardware accelerator for faster optimization.")

        # Handle missing chat template
        if not tokenizer.chat_template:
            logger.warning("Tokenizer does not have a chat template. Assuming base model and setting chat template to empty.")
            tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

    def run(
        self,
        messages: Union[List[Union[str, List[dict]]], Union[str, List[dict]]],
        targets: Union[List[Union[str, List[str]]], str],
    ) -> UDoraResult:
        """Execute UDora attack to generate adversarial strings.

        This is the main entry point for running UDora attacks. It handles both
        single and batch inputs, supporting various message formats and target
        specifications.

        Args:
            messages: Input messages/conversations to attack. Can be:
                - Single string
                - List of strings
                - Single conversation (list of dicts with 'role'/'content')
                - List of conversations
            targets: Target actions/tools to trigger. Can be:
                - Single target string
                - List of target strings
                - List of lists for multiple targets per message

        Returns:
            UDoraResult containing attack results, including best/final adversarial
            strings, generated responses, success indicators, and optimization history.
        """
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        if config.seed is not None:
            set_seed(config.seed)
            torch.use_deterministic_algorithms(True, warn_only=True)

        # Process input messages
        messages = self._process_input_messages(messages)
        self.targets = self._process_targets(targets, len(messages))

        # Initialize batch state
        self._initialize_batch_state(len(messages))

        # Prepare templates
        self.messages = messages
        templates = self._prepare_templates(messages)
        self.templates = templates

        # Build the initial target positions and reasoning context
        self._update_target_and_context(config.optim_str_init)

        # Initialize the attack buffer
        buffer = self.init_buffer()
        optim_ids = buffer.get_best_ids()

        losses = []
        optim_strings = []

        # Main optimization loop
        for step in tqdm(range(config.num_steps)):
            if self.stop_flag:
                losses.append(buffer.get_lowest_loss())
                optim_ids = buffer.get_best_ids()
                optim_str = tokenizer.batch_decode(optim_ids)[0]
                optim_strings.append(optim_str)
                break

            # Compute gradients and sample candidates
            optim_ids_onehot_grad = self.compute_token_gradient(optim_ids)

            with torch.no_grad():
                # Sample candidates, with readable guidance if enabled
                if config.readable and self.readable_optimizer:
                    sampled_ids = sample_ids_from_grad_readable(
                        optim_ids.squeeze(0),
                        optim_ids_onehot_grad.squeeze(0),
                        config.search_width,
                        self.readable_optimizer,
                        readable_strength=0.3,  # moderate strength for readability
                        topk=config.topk,
                        n_replace=config.n_replace,
                        not_allowed_ids=self.not_allowed_ids,
                    )
                else:
                    sampled_ids = sample_ids_from_grad(
                        optim_ids.squeeze(0),
                        optim_ids_onehot_grad.squeeze(0),
                        config.search_width,
                        topk=config.topk,
                        n_replace=config.n_replace,
                        not_allowed_ids=self.not_allowed_ids,
                    )

                if config.filter_ids:
                    sampled_ids = filter_ids(sampled_ids, tokenizer)

                # Compute loss on candidates
                loss = self._compute_batch_loss(sampled_ids)

                # Ensure loss is a tensor for min/argmin operations
                if not isinstance(loss, torch.Tensor):
                    loss = torch.tensor(loss, device=self.model.device)

                current_loss = loss.min().item()
                optim_ids = sampled_ids[loss.argmin()].unsqueeze(0)

                # Update buffer
                losses.append(current_loss)
                if buffer.size == 0 or len(buffer.buffer) < buffer.size or current_loss < buffer.get_highest_loss():
                    buffer.add(current_loss, optim_ids)

            optim_ids = buffer.get_best_ids()
            optim_str = tokenizer.batch_decode(optim_ids)[0]
            optim_strings.append(optim_str)

            # Log readability metrics if readable optimization is enabled
            if self.config.readable and self.readable_optimizer is not None:
                readability_metrics = self.readable_optimizer.evaluate_readability(optim_str)
                logger.info(f"Readability Score: {readability_metrics['readability_score']:.3f} "
                            f"(perplexity: {readability_metrics['perplexity']:.2f}, "
                            f"overall: {readability_metrics['overall_score']:.3f})")

            buffer.log_buffer(tokenizer)

            # Update context periodically
            if (step + 1) % self.config.prefix_update_frequency == 0:
                self._update_target_and_context(optim_str)
            else:
                self._generate_and_store_responses(optim_str)

            if self.stop_flag:
                logger.info("Early stopping due to finding a perfect match.")
                break

        # Generate final results
        return self._compile_results(losses, optim_strings)

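    # run() accepts several input shapes; the calls below are equivalent sketches
    # (the prompt text and target are illustrative, not prescribed values):
    #
    #     udora.run("Find wireless earbuds {optim_str}", "click[Buy Now]")
    #     udora.run([{"role": "user", "content": "Find wireless earbuds {optim_str}"}],
    #               ["click[Buy Now]"])
    #
    # If no "{optim_str}" placeholder appears in the message content, one is
    # appended to the last message automatically (see _prepare_templates).
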
    def _process_input_messages(self, messages):
        """Process and normalize input messages."""
        if isinstance(messages, str):
            return [[{"role": "user", "content": messages}]]
        elif check_type(messages, str):
            return [[{"role": "user", "content": message}] for message in messages]
        elif check_type(messages, dict):
            return [messages]
        return copy.deepcopy(messages)

    def _process_targets(self, targets, num_messages):
        """Process and normalize target specifications."""
        if isinstance(targets, str):
            return [[targets]] * num_messages
        elif check_type(targets, str):
            return [[target] for target in targets]
        else:
            return [[target for target in target_list] for target_list in targets]

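    # Normalization sketch: both helpers expand scalar inputs to per-batch lists
    # (the argument values are illustrative):
    #
    #     _process_input_messages("hi")      # -> [[{"role": "user", "content": "hi"}]]
    #     _process_targets("click[Buy]", 2)  # -> [["click[Buy]"], ["click[Buy]"]]
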
    def _initialize_batch_state(self, batch_size):
        """Initialize batch processing state variables."""
        self.batch_before_ids = [[] for _ in range(batch_size)]
        self.batch_after_ids = [[] for _ in range(batch_size)]
        self.batch_target_ids = [[] for _ in range(batch_size)]
        self.batch_before_embeds = [[] for _ in range(batch_size)]
        self.batch_after_embeds = [[] for _ in range(batch_size)]
        self.batch_target_embeds = [[] for _ in range(batch_size)]
        self.batch_prefix_cache = [[] for _ in range(batch_size)]
        self.batch_all_generation = [[] for _ in range(batch_size)]

    def _prepare_templates(self, messages):
        """Prepare conversation templates from messages."""
        templates = [
            self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
            if isinstance(message, list) else message
            for message in messages
        ]

        def bos_filter(template):
            # Strip only a leading BOS token; a plain str.replace would also
            # remove any later occurrences inside the template
            if self.tokenizer.bos_token and template.startswith(self.tokenizer.bos_token):
                return template[len(self.tokenizer.bos_token):]
            return template

        templates = list(map(bos_filter, templates))

        # Add an {optim_str} placeholder if not present
        for i, conversation in enumerate(messages):
            if isinstance(conversation, list):
                if not any("{optim_str}" in message.get("content", "") for message in conversation):
                    conversation[-1]["content"] += "{optim_str}"
                    # Regenerate the template with the placeholder included
                    templates[i] = self.tokenizer.apply_chat_template(
                        conversation, tokenize=False, add_generation_prompt=True
                    )
                    templates[i] = bos_filter(templates[i])

        return templates

    def _compute_batch_loss(self, sampled_ids):
        """Compute loss for a batch of candidate sequences."""
        new_search_width = sampled_ids.shape[0]
        batch_size = new_search_width if self.config.batch_size is None else self.config.batch_size

        total_loss = 0
        for idx in range(len(self.batch_before_embeds)):
            if self.batch_prefix_cache and self.batch_prefix_cache[idx]:
                input_embeds = torch.cat([
                    self.embedding_layer(sampled_ids),
                    self.batch_after_embeds[idx].repeat(new_search_width, 1, 1),
                    *[target_embedding.repeat(new_search_width, 1, 1)
                      for target_embedding in self.batch_target_embeds[idx]]
                ], dim=1)
            else:
                input_embeds = torch.cat([
                    self.batch_before_embeds[idx].repeat(new_search_width, 1, 1),
                    self.embedding_layer(sampled_ids),
                    self.batch_after_embeds[idx].repeat(new_search_width, 1, 1),
                    *[target_embedding.repeat(new_search_width, 1, 1)
                      for target_embedding in self.batch_target_embeds[idx]]
                ], dim=1)

            loss_result = find_executable_batch_size(self.compute_candidates_loss, batch_size)(input_embeds, idx)
            total_loss += loss_result

        return total_loss / len(self.batch_before_embeds)

    def _compute_generation_success(self, generation_str: str, batch_idx: int) -> bool:
        """Compute success status for a single generation."""
        if self.config.minimize_reward:
            return all(target not in generation_str for target in self.targets[batch_idx])
        else:
            return check_success_condition(generation_str, self.targets[batch_idx], self.config.dataset)

    def _generate_and_store_responses(self, optim_str):
        """Generate responses for the current adversarial string and store them."""
        for k, template in enumerate(self.templates):
            outputs, decoding_string = self.greedy_generation(template, optim_str, k)
            logger.info(f"Model's response: {decoding_string}")

            # Compute success status for this generation
            success = self._compute_generation_success(decoding_string, k)
            self.batch_all_generation[k].append((optim_str, decoding_string, success))

    def _compile_results(self, losses, optim_strings):
        """Compile final attack results."""
        # Generate responses for evaluation
        vanilla_generation = [self.greedy_generation(template, '')[1] for template in self.templates]
        last_generation = [self.greedy_generation(template, optim_strings[-1])[1]
                           for template in self.templates]

        # Record the last generation (with success status) for each batch
        last_success = []
        for k, gen in enumerate(last_generation):
            success = self._compute_generation_success(gen, k)
            last_success.append(success)
            self.batch_all_generation[k].append((optim_strings[-1], gen, success))

        # Compute success metrics
        def compute_success(generation_list):
            if self.config.minimize_reward:
                return [all(target not in gen for target in self.targets[i])
                        for i, gen in enumerate(generation_list)]
            else:
                return [check_success_condition(gen, self.targets[i], self.config.dataset)
                        for i, gen in enumerate(generation_list)]

        # Find the best result among successful generations, or fall back to minimum loss
        best_index = self._find_best_successful_index(losses, optim_strings)
        best_generation = [self.greedy_generation(template, optim_strings[best_index])[1]
                           for template in self.templates]

        return UDoraResult(
            best_loss=losses[best_index],
            best_string=optim_strings[best_index],
            best_generation=best_generation,
            best_success=compute_success(best_generation),
            last_string=optim_strings[-1],
            last_generation=last_generation,
            last_success=compute_success(last_generation),
            vanilla_generation=vanilla_generation,
            vanilla_success=compute_success(vanilla_generation),
            all_generation=self.batch_all_generation,
            losses=losses,
            strings=optim_strings,
        )

    def _find_best_successful_index(self, losses, optim_strings):
        """Find the index of the best successful attack, or fall back to minimum loss."""
        successful_indices = []

        # Determine the maximum number of generations across all batches
        max_generations = max(len(self.batch_all_generation[k]) for k in range(len(self.batch_all_generation))) if self.batch_all_generation else 0

        # Check each generation step for success across the whole batch
        for step_idx in range(max_generations):
            step_successful = True
            for batch_idx in range(len(self.batch_all_generation)):
                if step_idx < len(self.batch_all_generation[batch_idx]):
                    generation_tuple = self.batch_all_generation[batch_idx][step_idx]
                    # Extract the success status from the tuple. Tuples are either
                    # (optim_str, decoding_string, success) or
                    # (optim_str, decoding_string, surrogate_response, success).
                    if len(generation_tuple) >= 3 and isinstance(generation_tuple[-1], bool):
                        batch_success = generation_tuple[-1]
                    else:
                        # Fallback for tuples without a stored success status:
                        # recompute success from the generation string
                        generation_str = generation_tuple[1] if len(generation_tuple) > 1 else ""
                        batch_success = self._compute_generation_success(generation_str, batch_idx)
                    if not batch_success:
                        step_successful = False
                        break
                else:
                    # No generation recorded for this step; treat it as unsuccessful
                    step_successful = False
                    break
            if step_successful:
                successful_indices.append(step_idx)

        # Among successful attacks, find the one with minimum loss
        if successful_indices:
            if successful_indices[0] < len(losses):
                best_successful_idx = min(successful_indices, key=lambda idx: losses[idx] if idx < len(losses) else float('inf'))
                logger.info(f"Found {len(successful_indices)} successful attacks. Best successful at step {best_successful_idx} with loss {losses[best_successful_idx] if best_successful_idx < len(losses) else 'N/A'}.")
                return best_successful_idx
            else:
                # Early-stopping case: a successful attack was found before any optimization steps
                logger.info(f"Found {len(successful_indices)} successful attacks during initialization (before the optimization loop).")
                return 0  # Use the first available loss/string
        else:
            # No successful attacks found; fall back to minimum loss
            if losses:
                min_loss_idx = losses.index(min(losses))
                logger.info(f"No successful attacks found. Using minimum loss at step {min_loss_idx} with loss {losses[min_loss_idx]:.4f}")
                return min_loss_idx
            else:
                # No losses available (early stopping before any optimization)
                logger.info("No successful attacks found and no optimization losses available. Using index 0.")
                return 0

    def init_buffer(self) -> AttackBuffer:
        """Initialize the attack buffer with initial adversarial candidates."""
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        logger.info(f"Initializing attack buffer of size {config.buffer_size}...")
        buffer = AttackBuffer(config.buffer_size)

        if isinstance(config.optim_str_init, str):
            init_optim_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device)
            true_buffer_size = 1
        else:
            if len(config.optim_str_init) != config.buffer_size:
                logger.warning(f"Using {len(config.optim_str_init)} initializations but buffer size is set to {config.buffer_size}")
            try:
                init_optim_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device)
                true_buffer_size = init_optim_ids.shape[0]
            except ValueError:
                logger.error("Unable to create buffer. Ensure that all initializations tokenize to the same length.")
                raise

        # Compute initial losses
        init_buffer_losses = self._compute_batch_loss(init_optim_ids)

        # Ensure init_buffer_losses is a tensor with one entry per candidate
        if isinstance(init_buffer_losses, torch.Tensor):
            if init_buffer_losses.dim() == 0:
                # Scalar tensor; expand to match the buffer size
                init_buffer_losses = init_buffer_losses.unsqueeze(0).expand(true_buffer_size)
        else:
            # Convert to tensor if it's a scalar
            init_buffer_losses = torch.tensor([init_buffer_losses], device=model.device)

        # Populate buffer
        for i in range(true_buffer_size):
            loss_val = init_buffer_losses[i].item() if init_buffer_losses.numel() > i else init_buffer_losses.item()
            buffer.add(loss_val, init_optim_ids[[i]])

        buffer.log_buffer(tokenizer)
        logger.info("Initialized attack buffer.")
        return buffer

    def compute_token_gradient(self, optim_ids: Tensor) -> Tensor:
        """Compute gradients of the UDora loss w.r.t. one-hot token embeddings."""
        model = self.model
        embedding_layer = self.embedding_layer

        # Create one-hot encoding
        optim_ids_onehot = torch.nn.functional.one_hot(optim_ids, num_classes=embedding_layer.num_embeddings)
        optim_ids_onehot = optim_ids_onehot.to(dtype=model.dtype, device=model.device)
        optim_ids_onehot.requires_grad_()
        optim_embeds = optim_ids_onehot @ embedding_layer.weight

        # Compute loss across all batches
        total_loss = 0
        for idx in range(len(self.batch_before_embeds)):
            if self.batch_prefix_cache and self.batch_prefix_cache[idx]:
                input_embeds = torch.cat([optim_embeds, self.batch_after_embeds[idx], *self.batch_target_embeds[idx]], dim=1)
                output = model(inputs_embeds=input_embeds, past_key_values=self.batch_prefix_cache[idx])
            else:
                input_embeds = torch.cat([self.batch_before_embeds[idx], optim_embeds, self.batch_after_embeds[idx], *self.batch_target_embeds[idx]], dim=1)
                output = model(inputs_embeds=input_embeds)

            logits = output.logits
            batch_loss = self._compute_target_loss(logits, idx)
            total_loss += batch_loss

        mean_loss = total_loss / len(self.batch_before_embeds)
        optim_ids_onehot_grad = torch.autograd.grad(outputs=[mean_loss], inputs=[optim_ids_onehot])[0]

        if self.config.minimize_reward:
            optim_ids_onehot_grad = -optim_ids_onehot_grad

        return optim_ids_onehot_grad

    def _compute_target_loss(self, logits, idx):
        """Compute loss for target sequences in a batch."""
        shift = logits.shape[1] - sum(target_embedding.shape[1] for target_embedding in self.batch_target_embeds[idx])
        current_shift = shift - 1
        total_loss = 0

        for p in range(len(self.batch_target_embeds[idx])):
            length = self.batch_target_embeds[idx][p].shape[1]
            # Only compute loss on even indices (target segments); odd indices
            # hold the surrounding reasoning context
            if p % 2 == 0:
                positions = list(range(current_shift, current_shift + length))
                if self.config.use_mellowmax:
                    loss = apply_mellowmax_loss(
                        logits, self.batch_target_ids[idx][p], positions,
                        self.config.mellowmax_alpha, self.config.weight, p // 2
                    )
                else:
                    loss = compute_cross_entropy_loss(
                        logits, self.batch_target_ids[idx][p], positions,
                        self.config.weight, p // 2
                    )
                total_loss += loss
            current_shift += length

        return total_loss

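    # Layout sketch for _compute_target_loss (an inference from how
    # _update_batch_embeddings builds batch_target_embeds; indices illustrative):
    #
    #     batch_target_embeds[idx] = [target_0, context_0, target_1, context_1, ...]
    #                                  p = 0     p = 1      p = 2     p = 3
    #
    # Loss accumulates only at even p (the inserted target spans), while odd p
    # merely advances `current_shift` past the agent's own reasoning tokens.
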
    def compute_candidates_loss(self, search_batch_size: int, input_embeds: Tensor, idx: int) -> Tensor:
        """Compute UDora loss for candidate adversarial sequences."""
        all_loss = []
        prefix_cache_batch = []
        prefix_cache = self.batch_prefix_cache[idx] if self.batch_prefix_cache else None

        for i in range(0, input_embeds.shape[0], search_batch_size):
            with torch.no_grad():
                input_embeds_batch = input_embeds[i:i + search_batch_size]
                current_batch_size = input_embeds_batch.shape[0]

                if prefix_cache:
                    if not prefix_cache_batch or current_batch_size != search_batch_size:
                        prefix_cache_batch = [[x.expand(current_batch_size, -1, -1, -1) for x in prefix_cache[layer]]
                                              for layer in range(len(prefix_cache))]
                    outputs = self.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch)
                else:
                    outputs = self.model(inputs_embeds=input_embeds_batch)

                logits = outputs.logits

                # Compute UDora loss
                loss = self._compute_udora_batch_loss(logits, idx, current_batch_size)

                del outputs
                gc.collect()
                torch.cuda.empty_cache()

                all_loss.append(loss)

        result = torch.cat(all_loss, dim=0)
        return -result if self.config.minimize_reward else result

    def _compute_udora_batch_loss(self, logits, idx, batch_size):
        """Compute UDora loss for a batch of candidates."""
        shift = logits.shape[1] - sum(target_embedding.shape[1] for target_embedding in self.batch_target_embeds[idx])
        positions = []
        all_shift_labels = []
        current_shift = shift - 1

        for p in range(len(self.batch_target_embeds[idx])):
            length = self.batch_target_embeds[idx][p].shape[1]
            if p % 2 == 0:
                positions.append(list(range(current_shift, current_shift + length)))
                all_shift_labels.append(self.batch_target_ids[idx][p])
            current_shift += length

        current_loss = 0
        active_mask = torch.ones(batch_size, dtype=torch.bool, device=self.model.device)

        for p in range(len(positions)):
            loss = compute_udora_loss(
                logits, all_shift_labels[p], positions[p],
                self.config.weight, p
            )
            loss = loss * active_mask.float()
            current_loss += loss

            if not self.config.sequential:
                # Placeholder for joint-mode gating: this mask update is currently
                # a no-op (the mask stays all-True), so every target location keeps
                # contributing to the loss
                active_mask = active_mask & torch.ones(batch_size, dtype=torch.bool, device=self.model.device)
                if not active_mask.any():
                    break

        return current_loss

    @torch.no_grad()
    def _update_target_and_context(self, optim_str):
        """Update target positions and reasoning context using interval scheduling."""
        model = self.model
        tokenizer = self.tokenizer

        for k, template in enumerate(self.templates):
            # Generate a response and collect per-step token probabilities
            outputs, decoding_string = self.greedy_generation(template, optim_str, k)
            generated_ids = outputs.generated_ids
            logits = outputs.scores
            probs_list = [logits_step.softmax(dim=-1)[0] for logits_step in logits]

            # Build target intervals
            intervals = build_target_intervals(
                generated_ids, self.targets[k], tokenizer, probs_list,
                self.config.add_space_before_target, self.config.before_negative
            )

            # Apply weighted interval scheduling
            selected_indices = weighted_interval_scheduling(intervals, self.config.num_location)

            # Filter based on sequential mode
            selected_indices = filter_intervals_by_sequential_mode(
                intervals, selected_indices, self.config.sequential
            )

            # Build final token sequence
            final_generated_ids = build_final_token_sequence(
                intervals, selected_indices, generated_ids, model.device
            )

            # Log results
            matched_count = count_matched_locations(intervals)
            logger.info(f"Batch {k}: Found {len(intervals)} intervals, {matched_count} matched, selected {len(selected_indices)}")

            all_response_ids = torch.cat(final_generated_ids, dim=-1) if final_generated_ids else torch.tensor([], device=model.device)
            surrogate_response = tokenizer.decode(all_response_ids[0] if len(all_response_ids.shape) > 1 else all_response_ids, skip_special_tokens=False)
            logger.info(f"Target context update:\n**Generated response**: {decoding_string}\n**Surrogate response**: {surrogate_response}")

            # Update embeddings and context
            self._update_batch_embeddings(k, template, final_generated_ids)

            # Compute success status for this generation
            success = self._compute_generation_success(decoding_string, k)
            self.batch_all_generation[k].append((optim_str, decoding_string, surrogate_response, success))

            del outputs
            gc.collect()
            torch.cuda.empty_cache()

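    # Data-flow sketch for _update_target_and_context (an inference from the
    # scheduling/text_processing helpers; values illustrative):
    #
    #     generation: "I should search first ... click[Buy Now] ..."
    #     intervals:  candidate spans where each target string could be grafted
    #                 into the reasoning trace, scored by token probabilities
    #     scheduling: weighted_interval_scheduling picks up to num_location
    #                 non-overlapping spans with maximal total score
    #     output:     [context_0, target_0, context_1, target_1, ...] token chunks,
    #                 where context_0 is folded into `after_ids` downstream
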
    def _update_batch_embeddings(self, k, template, final_generated_ids):
        """Update batch embeddings and cache for a specific batch index."""
        before_str, after_str = template.split("{optim_str}")
        before_ids = self.tokenizer([before_str.rstrip()], add_special_tokens=False, padding=False, return_tensors="pt")["input_ids"].to(self.model.device)
        after_ids = self.tokenizer([after_str.lstrip()], add_special_tokens=False, padding=False, return_tensors="pt")["input_ids"].to(self.model.device)

        if final_generated_ids:
            # The first chunk is the reasoning context preceding the first target;
            # fold it into the post-prompt ids so the remaining list alternates
            # target/context starting with a target at index 0
            assistant_response_ids = final_generated_ids.pop(0)
            after_ids = torch.cat((after_ids, assistant_response_ids), dim=-1)

        # Update embeddings
        before_embeds, after_embeds = [self.embedding_layer(ids) for ids in (before_ids, after_ids)]
        target_embeds = [self.embedding_layer(ids) for ids in final_generated_ids]

        # Update prefix cache if needed
        if self.config.use_prefix_cache and isinstance(self.batch_before_embeds[k], list):
            output = self.model(inputs_embeds=before_embeds, use_cache=True)
            self.batch_prefix_cache[k] = output.past_key_values

        self.batch_before_ids[k] = before_ids
        self.batch_before_embeds[k] = before_embeds
        self.batch_after_ids[k] = after_ids
        self.batch_target_ids[k] = final_generated_ids
        self.batch_after_embeds[k] = after_embeds
        self.batch_target_embeds[k] = target_embeds

    @torch.no_grad()
    def greedy_generation(self, template, optim_str, idx=None):
        """Generate an agent response using greedy decoding."""
        # Prepare input; for the vanilla case, drop the placeholder (and any
        # preceding space) entirely
        if optim_str == '':
            input_ids = self.tokenizer([template.replace(' {optim_str}', '').replace('{optim_str}', '')],
                                       add_special_tokens=False, padding=False, return_tensors="pt")["input_ids"].to(self.model.device)
        else:
            input_ids = self.tokenizer([template.replace('{optim_str}', optim_str)],
                                       add_special_tokens=False, padding=False, return_tensors="pt")["input_ids"].to(self.model.device)
        attn_masks = torch.ones_like(input_ids).to(self.model.device)

        # Generate response
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attn_masks,
            top_k=0,
            top_p=1.0,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            output_hidden_states=False,
            output_attentions=False,
            pad_token_id=self.tokenizer.eos_token_id,
            past_key_values=self.batch_prefix_cache[idx] if idx is not None and self.batch_prefix_cache and self.batch_prefix_cache[idx] else None,
            max_new_tokens=self.config.max_new_tokens,
        )

        outputs.generated_ids = outputs.sequences[0, input_ids.shape[1]:].tolist()
        if self.tokenizer.eos_token_id is not None and outputs.generated_ids and outputs.generated_ids[-1] == self.tokenizer.eos_token_id:
            outputs.generated_ids = outputs.generated_ids[:-1]

        decoded_string = self.tokenizer.decode(outputs.generated_ids, skip_special_tokens=False)

        # Check for early stopping
        if idx is not None and self.config.early_stop:
            if self.config.minimize_reward:
                if all(tgt_text not in decoded_string for tgt_text in self.targets[idx]):
                    self.stop_flag = True
            elif check_success_condition(decoded_string, self.targets[idx], self.config.dataset):
                self.stop_flag = True

        return outputs, decoded_string

def run(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    messages: Union[str, List[dict]],
    target: str,
    config: Optional[UDoraConfig] = None,
) -> UDoraResult:
    """Execute a UDora attack with a simplified API."""
    if config is None:
        config = UDoraConfig()

    logger.setLevel(getattr(logging, config.verbosity))
    udora = UDora(model, tokenizer, config)
    result = udora.run(messages, target)
    return result
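
# Minimal usage sketch (kept under a __main__ guard so importing this module has
# no side effects). The model name, prompt, and target below are illustrative
# assumptions, not values prescribed by UDora.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical agent backbone
    demo_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    demo_tokenizer = AutoTokenizer.from_pretrained(model_name)

    demo_config = UDoraConfig(num_steps=50, search_width=256, early_stop=True, dataset="webshop")
    demo_result = run(
        demo_model,
        demo_tokenizer,
        messages="I want a pair of wireless earbuds. {optim_str}",  # {optim_str} marks the adversarial suffix
        target="click[Buy Now]",  # hypothetical target action for the webshop setting
        config=demo_config,
    )
    print(demo_result.best_string, demo_result.best_success)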