Add unrolled methods (claude_oss_v53, claude_v63) with README

Rename claude_v53 to claude_oss_v53 to match safeguard track naming
convention. Add README documenting what unrolled methods are and how
to create them.
Peter Romov
2026-03-25 18:27:52 +00:00
parent 5b6058b3c4
commit 63974ddfee
6 changed files with 624 additions and 0 deletions
@@ -0,0 +1,18 @@
# Unrolled Methods
Unrolled versions are **standalone, cleaned-up rewrites** of existing methods. The originals often inherit from each other in long chains and accumulate dead code. An unrolled version flattens the inheritance, keeps only the logic that matters, and makes the method readable as a single file.
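
As a rough illustration of what flattening means, the sketch below uses a hypothetical three-level chain (`BaseMethod`, `MethodV2`, `MethodV3`) and a hypothetical flat rewrite (`MethodV3Unrolled`). None of these names exist in this repository. They are toy stand-ins that show the inherited override and the dead branch disappearing while the result stays the same.

```python
# Hypothetical toy classes, for illustration only.
class BaseMethod:
    temperature = 1.0
    use_legacy_filter = False  # never enabled by any subclass: dead code

    def score(self, values):
        if self.use_legacy_filter:  # dead branch
            values = [v for v in values if v > 0]
        return [v / self.temperature for v in values]


class MethodV2(BaseMethod):
    temperature = 0.5  # only overrides a hyperparameter


class MethodV3(MethodV2):
    def score(self, values):
        # adds an offset on top of the inherited scoring
        return [v + 1.0 for v in super().score(values)]


class MethodV3Unrolled:
    """Flat rewrite: same behaviour, no parents, no dead branch."""

    TEMPERATURE = 0.5

    def score(self, values):
        return [v / self.TEMPERATURE + 1.0 for v in values]


original = MethodV3().score([1.0, 2.0])
unrolled = MethodV3Unrolled().score([1.0, 2.0])
```

Reading `MethodV3Unrolled` requires no knowledge of the parents, which is the property the unrolled methods aim for.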
## Why unroll
- **Auditability**: one file, no inheritance maze, every line does something.
- **Reproducibility**: self-contained code is easier to port or share.
- **Documentation**: the module docstring becomes the ground truth for what the method does.
## How to create one
1. **Read the original** method chain top-to-bottom. Identify which pieces of inherited logic are actually used.
2. **Write a flat optimizer** that subclasses `TokenOptimizer` directly. Inline all logic into `setup()`, `step()`, and a few private helpers. No intermediate base classes.
3. **Simplify**: remove dead branches, unused hyperparameters, and compatibility shims. Three clear lines beat one clever abstraction.
4. **Verify equivalence**: run a small experiment with the same seed on both the original and unrolled version. The results must match bit-for-bit. Run a few configs if the method has branching logic (e.g. different `n_replace` phases).
5. **Write the module docstring**: one-line summary of what the method combines, a numbered list of component ideas with short explanations and paper references, and a pseudocode block showing the full optimization loop.
6. **Name it**: set `method_name` to `<original_method_name>_unrolled`, where `<original_method_name>` is the original's `method_name`. Place the code in `claude_unrolled/<method_dir>/` with `optimizer.py` and `__init__.py`.
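
The equivalence check in step 4 can be sketched as below. This is a minimal, self-contained illustration that assumes toy stand-ins rather than the repository's API: `run_trace`, `traces_match`, and the three `*_step` functions are hypothetical names, and a seeded `random.Random` stands in for the real optimizer's random state. The point is the shape of the check: same seed, exact comparison, and enough steps to cross every branch.

```python
import random
from typing import Callable, Iterable, List

StepFn = Callable[[random.Random, int], float]


def run_trace(step_fn: StepFn, seed: int, num_steps: int) -> List[float]:
    """Run a toy optimization loop and return the per-step losses."""
    rng = random.Random(seed)
    return [step_fn(rng, step) for step in range(num_steps)]


def traces_match(original: StepFn, unrolled: StepFn,
                 seeds: Iterable[int], num_steps: int) -> bool:
    """Exact (not approximate) comparison of loss traces for every seed."""
    return all(
        run_trace(original, seed, num_steps) == run_trace(unrolled, seed, num_steps)
        for seed in seeds
    )


def original_step(rng: random.Random, step: int) -> float:
    n_replace = 2 if step < 8 else 1  # toy coarse-to-fine switch at step 8
    loss = 0.0
    for _ in range(n_replace):
        loss += rng.random()
    return loss


def unrolled_step(rng: random.Random, step: int) -> float:
    # Same logic written flat: identical draws in identical order.
    loss = rng.random()
    if step < 8:
        loss += rng.random()
    return loss


def broken_step(rng: random.Random, step: int) -> float:
    # Forgets the switch, so it diverges once step >= 8.
    return rng.random() + rng.random()


ok = traces_match(original_step, unrolled_step, seeds=[0, 1, 2], num_steps=10)
bad = traces_match(original_step, broken_step, seeds=[0, 1, 2], num_steps=10)
```

Note that `num_steps=10` crosses the toy switch at step 8. A run that stops before the switch would not expose `broken_step`, which is why the list above asks for configs that exercise each branch.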
@@ -0,0 +1 @@
@@ -0,0 +1,3 @@
from .optimizer import ClaudeOssV53UnrolledOptimizer
__all__ = ["ClaudeOssV53UnrolledOptimizer"]
@@ -0,0 +1,301 @@
"""Claude OSS v53 (unrolled): MAC + TAO DPTO with Coarse-to-Fine Replacement.

Combines three ideas:

1. **DPTO candidate selection** (Direction-Priority Token Optimization).
   For each position, filters vocabulary tokens by cosine similarity
   between the negative gradient direction and candidate displacement
   vectors, then samples replacements from the filtered set via
   temperature-scaled softmax over projected step magnitudes.
   Reference: "TAO-Attack: Toward Advanced Optimization-Based Jailbreak
   Attacks for Large Language Models" (Xu et al., ICLR 2026,
   arXiv:2603.03081).

2. **Momentum-smoothed embedding gradients** (MAC).
   An exponential moving average (EMA) of the embedding-space gradient
   replaces the raw per-step gradient, reducing noise in the search
   direction. The momentum gradient is fed into DPTO for candidate
   selection.
   Reference: "Boosting Jailbreak Attack with Momentum" (Zhang & Wei,
   ICASSP 2025).

3. **Coarse-to-fine replacement schedule** (novel, claude_oss_v53).
   For the first 80% of optimization steps, each candidate replaces
   n_replace=2 positions (coarse exploration). For the final 20%,
   the method switches to n_replace=1 (fine-grained refinement of
   individual positions). The switch step is computed from an
   estimated budget of 131 steps, i.e. step 104 with the defaults.

Pseudocode::

    x ~ random tokens                                   # [L]
    m = None                                            # momentum buffer

    for each step:
        # --- coarse-to-fine schedule ---
        n_replace = 2 if step < 0.8 * estimated_steps else 1

        # --- embedding gradient ---
        embed = one_hot(x) @ W_embed                    # [1, L, D]
        loss = CE(model([prefix | embed | suffix | target]), target)
        g = d(loss)/d(embed)                            # [1, L, D]

        # --- momentum update (MAC) ---
        m = g if m is None else mu * m + (1 - mu) * g   # EMA

        # --- DPTO candidate selection (TAO), driven by m ---
        for each position:
            cos = cosine(m_pos, embed_pos - W_embed)
            top_indices = topk(cos, topk_per_position)
            dot_scores = m_pos . (embed_pos - W_embed[top_indices])
            probs = softmax(dot_scores / temperature)
        sample B candidates, each replacing n_replace positions

        # --- evaluate and keep best ---
        losses = [CE(model([prefix | cand | suffix | target]), target) for cand in candidates]
        x = candidates[argmin(losses)]
"""
import logging
import torch
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from claudini.base import TokenOptimizer
logger = logging.getLogger("claudini")
class ClaudeOssV53UnrolledOptimizer(TokenOptimizer):
    """MAC + TAO DPTO with coarse-to-fine replacement. See module docstring."""

    method_name = "claude_oss_v53_unrolled"
    is_soft = False

    # -- Hyperparameter defaults ------------------------------------------------
    DEFAULT_NUM_CANDIDATES = 80
    DEFAULT_TOPK_PER_POSITION = 300
    DEFAULT_TEMPERATURE = 0.4
    DEFAULT_N_REPLACE = 2
    DEFAULT_MOMENTUM = 0.908
    DEFAULT_SWITCH_FRACTION = 0.8

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        optim_length: int = 20,
        num_candidates: int = DEFAULT_NUM_CANDIDATES,
        topk_per_position: int = DEFAULT_TOPK_PER_POSITION,
        temperature: float = DEFAULT_TEMPERATURE,
        n_replace: int = DEFAULT_N_REPLACE,
        momentum: float = DEFAULT_MOMENTUM,
        switch_fraction: float = DEFAULT_SWITCH_FRACTION,
        seed: int | None = None,
        allow_non_ascii: bool = True,
    ):
        super().__init__(model, tokenizer, optim_length, seed, allow_non_ascii)

        # Hyperparameters
        self.num_candidates = num_candidates
        self.topk_per_position = topk_per_position
        self.temperature = temperature
        self.n_replace = n_replace
        self.momentum = momentum
        self.switch_fraction = switch_fraction

        # State (populated in setup)
        self.current_ids: Tensor | None = None
        self.momentum_grad: Tensor | None = None
        # Step budget assumed when placing the coarse-to-fine switch in step()
        self._estimated_steps = 131

    # -- Setup ------------------------------------------------------------------
    def setup(self, prompt: str, target: str) -> None:
        self._prepare_prompt(prompt, target)
        self.current_ids = self._init_optim_ids().unsqueeze(0)
        self.momentum_grad = None
        logger.info(
            "Claude OSS v53 (unrolled): B=%d, topk=%d, temp=%.2f, n_replace=%d->1 at %.0f%%, momentum=%.3f",
            self.num_candidates,
            self.topk_per_position,
            self.temperature,
            self.n_replace,
            self.switch_fraction * 100,
            self.momentum,
        )

    # -- Step -------------------------------------------------------------------
    def step(self, step_num: int) -> tuple[float, float | None, str]:
        # Coarse-to-fine: drop from self.n_replace to 1 once step_num reaches
        # switch_fraction of the estimated step budget
        switch_step = int(self._estimated_steps * self.switch_fraction)
        current_n_replace = 1 if step_num >= switch_step else self.n_replace

        self.log("temperature", self.temperature, prog_bar=True)
        self.log("n_replace", float(current_n_replace), prog_bar=True)

        # 1. Compute embedding-space gradient (one fwd+bwd)
        grad, optim_embeds = self._compute_embed_gradient(self.current_ids)
        self.flop_counter.count_forward_backward(self.total_seq_len)

        with torch.no_grad():
            # 2. Update momentum on embedding gradient (MAC)
            if self.momentum_grad is None:
                self.momentum_grad = grad.clone()
            else:
                self.momentum_grad = self.momentum * self.momentum_grad + (1 - self.momentum) * grad

            # 3. DPTO candidate selection using momentum gradient (TAO)
            sampled_ids = self._dpto_sample(
                self.current_ids.squeeze(0),
                optim_embeds.squeeze(0),
                self.momentum_grad.squeeze(0),
                current_n_replace,
            )
            actual_B = sampled_ids.shape[0]

            # 4. Evaluate candidates
            batch_losses = self._eval_candidates(sampled_ids)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=actual_B)

            # 5. Keep best
            best_idx = batch_losses.argmin()
            best_loss = float(batch_losses[best_idx].item())
            self.current_ids = sampled_ids[best_idx].unsqueeze(0)

        optim_str = self.tokenizer.batch_decode(self.current_ids)[0]
        self._step_ids = self.current_ids.squeeze(0)
        return best_loss, None, optim_str

    # -- Embedding gradient -----------------------------------------------------
    def _compute_embed_gradient(self, optim_ids: Tensor) -> tuple[Tensor, Tensor]:
        """Compute gradient of CE loss w.r.t. token embeddings.

        Returns:
            grad: [1, L, D] gradient in embedding space
            optim_embeds: [1, L, D] current token embeddings (detached)
        """
        embedding_layer = self.embedding_layer
        optim_ids_onehot = torch.nn.functional.one_hot(
            optim_ids,
            num_classes=embedding_layer.num_embeddings,
        ).to(self.model.device, self.model.dtype)

        optim_embeds = (optim_ids_onehot @ embedding_layer.weight).detach().clone()
        optim_embeds.requires_grad_()

        input_embeds = torch.cat(
            [self.before_embeds, optim_embeds, self.after_embeds, self.target_embeds],
            dim=1,
        )
        output = self.model(inputs_embeds=input_embeds)
        logits = output.logits

        shift = input_embeds.shape[1] - self.target_ids.shape[1]
        target_len = self.target_ids.shape[1]
        shift_logits = logits[..., shift - 1 : shift - 1 + target_len, :].contiguous()

        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            self.target_ids.view(-1),
        )

        grad = torch.autograd.grad(outputs=[loss], inputs=[optim_embeds])[0]
        return grad, optim_embeds.detach()

    # -- DPTO candidate sampling ------------------------------------------------
    def _dpto_sample(
        self,
        control_toks: Tensor,
        optim_embeds: Tensor,
        grad: Tensor,
        n_replace: int,
    ) -> Tensor:
        """Direction-Priority Token Optimization sampling using momentum gradient.

        Args:
            control_toks: [L] current suffix token ids
            optim_embeds: [L, D] current token embeddings
            grad: [L, D] momentum gradient in embedding space
            n_replace: number of positions to replace per candidate

        Returns:
            new_ids: [B, L] candidate sequences
        """
        eps = 1e-12
        embed_weights = self.embedding_layer.weight.detach()  # [V, D]
        L, D = optim_embeds.shape
        device = grad.device

        # Step 1: Cosine similarity per position
        grad_norm = grad / (grad.norm(dim=-1, keepdim=True) + eps)
        topk = min(self.topk_per_position, embed_weights.shape[0])
        top_indices = torch.empty(L, topk, device=device, dtype=torch.long)

        for pos in range(L):
            dir_pos = optim_embeds[pos] - embed_weights  # [V, D]
            dir_norm_pos = dir_pos / (dir_pos.norm(dim=-1, keepdim=True) + eps)
            cos_pos = grad_norm[pos] @ dir_norm_pos.T  # [V]

            # Mask forbidden tokens
            if self.not_allowed_ids is not None:
                cos_pos[self.not_allowed_ids.to(device)] = -float("inf")
            cos_pos[control_toks[pos]] = -float("inf")

            _, top_indices[pos] = cos_pos.topk(topk)

        # Step 2: Projected step within filtered set
        candidate_embeds = embed_weights[top_indices]  # [L, k, D]
        candidate_dirs = optim_embeds.unsqueeze(1) - candidate_embeds  # [L, k, D]
        dot_scores = torch.einsum("ld,lkd->lk", grad, candidate_dirs)  # [L, k]

        # Step 3: Temperature-scaled softmax sampling
        probs = torch.softmax(dot_scores / max(self.temperature, eps), dim=1)

        # Sample candidates
        B = self.num_candidates
        original_ids = control_toks.repeat(B, 1)  # [B, L]

        if n_replace == 1:
            samples_per_pos = B // L
            remainder = B % L
            all_positions = []
            all_tokens = []
            for pos in range(L):
                n = samples_per_pos + (1 if pos < remainder else 0)
                if n > 0:
                    token_indices = torch.multinomial(probs[pos], n, replacement=True)
                    token_ids = top_indices[pos][token_indices]
                    all_positions.extend([pos] * n)
                    all_tokens.append(token_ids)
            positions = torch.tensor(all_positions, device=device, dtype=torch.long)
            tokens = torch.cat(all_tokens, dim=0)
            original_ids[torch.arange(B, device=device), positions] = tokens
        else:
            for b in range(B):
                pos_perm = torch.randperm(L, device=device)[:n_replace]
                for pos in pos_perm:
                    token_idx = torch.multinomial(probs[pos], 1).item()
                    original_ids[b, pos] = top_indices[pos, token_idx]

        return original_ids

    # -- Candidate evaluation ---------------------------------------------------
    def _eval_candidates(self, sampled_ids: Tensor) -> Tensor:
        """Evaluate loss on candidate sequences."""
        actual_B = sampled_ids.shape[0]
        embedding_layer = self.embedding_layer
        input_embeds = torch.cat(
            [
                self.before_embeds.expand(actual_B, -1, -1),
                embedding_layer(sampled_ids),
                self.after_embeds.expand(actual_B, -1, -1),
                self.target_embeds.expand(actual_B, -1, -1),
            ],
            dim=1,
        )
        return self.batched_loss(input_embeds)
@@ -0,0 +1,3 @@
from .optimizer import ClaudeV63UnrolledOptimizer
__all__ = ["ClaudeV63UnrolledOptimizer"]
@@ -0,0 +1,298 @@
"""Claude v63 (unrolled): Decoupled ADC with Layer-wise Gradient Scaling.

Combines three ideas:

1. **ADC** (Adaptive Dense-to-sparse Constrained optimization).
   Optimizes soft probability distributions z of shape [K, L, V] over
   the vocabulary via SGD with heavy momentum. An adaptive sparsity
   schedule gradually constrains each distribution from dense (full
   vocabulary) to sparse (near one-hot) based on how many target tokens
   the current restart mispredicts.
   Reference: "Efficient LLM Jailbreak via Adaptive Dense-to-sparse
   Constrained Optimization" (NeurIPS 2024, arXiv:2405.09113).

2. **Decoupled K/lr** (claude_v19).
   The original ADC averages the CE loss over the K restarts, coupling
   the effective gradient magnitude to K. Here the loss is *summed*
   over restarts so that the learning rate is independent of K. This
   lets K control exploration breadth while lr controls step size.

3. **LSGM** (Language Skip Gradient Method).
   Backward hooks on every LayerNorm module scale incoming gradients
   by gamma < 1, amplifying the skip-connection gradient signal
   relative to the residual branch. Originally proposed for GCG in
   "Improved Generation of Adversarial Examples Against Safety-aligned
   LLMs" (Li et al., NeurIPS 2024, arXiv:2405.20778); here applied
   to ADC's continuous optimization with a milder gamma (0.85 vs the
   paper's 0.5).

Pseudocode::

    z ~ softmax(Normal(0, I))                              # [K, L, V]
    register backward hooks: grad *= gamma on all LayerNorm modules

    for each step:
        # --- soft forward ---
        soft_embeds = z @ W_embed                          # [K, L, D]
        logits = model([prefix | soft_embeds | suffix | target])  # [K, S, V]
        loss_k = CE(logits, target).mean(over tokens)      # [K]
        loss = sum(loss_k)                                 # scalar, decoupled from K
        loss.backward()
        SGD.step()

        # --- adaptive sparsity ---
        wrong_k = count_mispredictions(logits, target)     # [K]
        ema_wrong += alpha * (wrong_k - ema_wrong)
        S_k = clamp(2 ^ ema_wrong_k, max=V/2)
        z_pre = z.clone()
        z = sparsify(z, S_k)                               # keep top-S per position

        # --- discrete evaluation ---
        ids_k = argmax(z_pre, dim=-1)                      # [K, L]
        losses_k = CE_discrete(ids_k)                      # [K]
        track global best across all steps
"""
import logging
import torch
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from claudini.base import TokenOptimizer
logger = logging.getLogger("claudini")

# LayerNorm module name patterns (covers Llama, Gemma, GPT-2, etc.)
_NORM_PATTERNS = (
    "input_layernorm",
    "post_attention_layernorm",
    "pre_feedforward_layernorm",
    "post_feedforward_layernorm",
    ".ln_1",
    ".ln_2",
)


class ClaudeV63UnrolledOptimizer(TokenOptimizer):
    """Decoupled ADC with LSGM gradient scaling. See module docstring."""

    method_name = "claude_v63_unrolled"
    is_soft = True

    # ── Hyperparameter defaults ──────────────────────────────────────
    DEFAULT_LR = 10.0
    DEFAULT_MOMENTUM = 0.99
    DEFAULT_EMA_ALPHA = 0.01
    DEFAULT_NUM_STARTS = 6
    DEFAULT_LSGM_GAMMA = 0.85

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        optim_length: int = 20,
        lr: float = DEFAULT_LR,
        momentum: float = DEFAULT_MOMENTUM,
        ema_alpha: float = DEFAULT_EMA_ALPHA,
        num_starts: int = DEFAULT_NUM_STARTS,
        lsgm_gamma: float = DEFAULT_LSGM_GAMMA,
        seed: int | None = None,
        allow_non_ascii: bool = False,
    ):
        super().__init__(model, tokenizer, optim_length, seed, allow_non_ascii)

        # Hyperparameters
        self.lr = lr
        self.momentum = momentum
        self.ema_alpha = ema_alpha
        self.num_starts = num_starts
        self.lsgm_gamma = lsgm_gamma

        # State (populated in setup)
        self.soft_opt: torch.nn.Parameter | None = None
        self.optimizer: torch.optim.SGD | None = None
        self.running_wrong: Tensor | None = None
        self._global_best_loss: float = float("inf")
        self._global_best_ids: Tensor | None = None
        self._lsgm_handles: list = []

    # ── Setup ────────────────────────────────────────────────────────
    def setup(self, prompt: str, target: str) -> None:
        self._prepare_prompt(prompt, target)

        K = self.num_starts
        device = self.model.device

        # z ~ softmax(N(0, I)) for K restarts: [K, L, V]
        z = torch.randn(K, self.optim_length, self.vocab_size, device=device)
        if self.forbidden_mask is not None:
            z[:, :, self.forbidden_mask] = -1e10
        z = z.softmax(dim=-1)

        self.soft_opt = torch.nn.Parameter(z)
        self.optimizer = torch.optim.SGD([self.soft_opt], lr=self.lr, momentum=self.momentum)
        self.running_wrong = None
        self._global_best_loss = float("inf")
        self._global_best_ids = None

        # LSGM: backward hooks that scale LayerNorm gradients by gamma
        self._lsgm_handles = self._register_lsgm_hooks()

        logger.info(
            "Claude v63 (unrolled): LSGM(%d hooks, gamma=%.2f), K=%d, lr=%.1f",
            len(self._lsgm_handles),
            self.lsgm_gamma,
            self.num_starts,
            self.lr,
        )

    # ── Step ─────────────────────────────────────────────────────────
    def step(self, step_num: int) -> tuple[float, float | None, str]:
        K = self.num_starts
        self.optimizer.zero_grad()

        # 1. Soft embeddings: [K, L, V] @ [V, D] -> [K, L, D]
        W = self.embedding_layer.weight.detach()
        soft_embeds = torch.matmul(
            self.soft_opt.to(torch.float32),
            W.to(torch.float32),
        ).to(self.model_dtype)

        # 2. Batched forward
        input_embeds = torch.cat(
            [
                self.before_embeds.expand(K, -1, -1),
                soft_embeds,
                self.after_embeds.expand(K, -1, -1),
                self.target_embeds.expand(K, -1, -1),
            ],
            dim=1,
        )
        logits = self.model(inputs_embeds=input_embeds).logits

        shift = input_embeds.shape[1] - self.target_ids.shape[1]
        target_len = self.target_ids.shape[1]
        shift_logits = logits[..., shift - 1 : shift - 1 + target_len, :].contiguous()

        # 3. CE loss: mean over tokens, SUM over K (decouples lr from K)
        target_expanded = self.target_ids.expand(K, -1)
        loss_per_token = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            target_expanded.reshape(-1),
            reduction="none",
        )
        loss_per_restart = loss_per_token.view(K, target_len).mean(dim=1)
        soft_loss = loss_per_restart.sum()
        # Logged value is the mean over restarts, comparable across K
        soft_loss_val = float(soft_loss.item() / K)

        with torch.no_grad():
            preds = shift_logits.argmax(dim=-1)
            wrong_counts = (preds != target_expanded).float().sum(dim=1)

        # 4. Backward + SGD update
        soft_loss.backward()
        self.optimizer.step()
        self.flop_counter.count_forward_backward(self.total_seq_len, batch_size=K)

        with torch.no_grad():
            # 5. Adaptive sparsity: S_k = 2^(EMA of wrong count)
            if self.running_wrong is None:
                self.running_wrong = wrong_counts.clone()
            else:
                self.running_wrong += (wrong_counts - self.running_wrong) * self.ema_alpha
            sparsities = (2.0**self.running_wrong).clamp(max=self.vocab_size / 2)

            if self.forbidden_mask is not None:
                self.soft_opt.data[:, :, self.forbidden_mask] = -1000.0

            pre_sparse = self.soft_opt.data.clone()

            # 6. Sparsify: keep top-S per position per restart
            sparse_z = self._make_sparse_batched(self.soft_opt.data, sparsities)
            self.soft_opt.data.copy_(sparse_z)

            # 7. Discrete eval: argmax per restart, pick global best
            all_ids = pre_sparse.argmax(dim=-1)
            discrete_losses = self.compute_discrete_loss_batch(all_ids)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=K)

        best_k = discrete_losses.argmin().item()
        step_best_loss = discrete_losses[best_k].item()
        if step_best_loss < self._global_best_loss:
            self._global_best_loss = step_best_loss
            self._global_best_ids = all_ids[best_k].clone()

        self._step_ids = self._global_best_ids
        optim_str = self.tokenizer.decode(self._global_best_ids)
        return step_best_loss, soft_loss_val, optim_str

    # ── Run (wraps base to ensure hook cleanup) ──────────────────────
    def run(self, prompt, target, num_steps, max_flops=None, max_time=None, **kwargs):
        try:
            return super().run(prompt, target, num_steps, max_flops=max_flops, max_time=max_time, **kwargs)
        finally:
            self._remove_hooks()

    # ── LSGM hooks ───────────────────────────────────────────────────
    def _register_lsgm_hooks(self) -> list:
        handles = []
        gamma = self.lsgm_gamma
        for name, module in self.model.named_modules():
            if any(p in name for p in _NORM_PATTERNS):

                # gamma is bound as a default argument so each hook keeps its own copy
                def hook(m, grad_input, grad_output, _gamma=gamma):
                    # Scale the gradient w.r.t. the norm's input in place
                    grad_input[0].data *= _gamma

                handles.append(module.register_full_backward_hook(hook))
        return handles

    def _remove_hooks(self) -> None:
        for h in self._lsgm_handles:
            h.remove()
        self._lsgm_handles.clear()

    # ── Adaptive sparsification (Algorithm 1 from ADC paper) ─────────
    @torch.no_grad()
    def _make_sparse_batched(self, z: Tensor, sparsities: Tensor) -> Tensor:
        """Keep the top-S entries per position per restart, zero out the rest, renormalize.

        z: [K, L, V] soft distributions
        sparsities: [K] per-restart sparsity targets (continuous-valued)
        """
        K, L, V = z.shape
        result = z.clone()

        for k in range(K):
            s_float = sparsities[k].item()
            S_floor = int(s_float)
            S_frac = s_float - S_floor

            if S_floor >= V:
                result[k] = result[k].relu() + 1e-6
                result[k] /= result[k].sum(dim=-1, keepdim=True)
                continue

            # Positions getting floor+1 tokens (reference clamps to min=5)
            n_higher = max(int(S_frac * L), min(5, L))
            perm = torch.randperm(L, device=z.device)

            for j in range(L):
                pos = perm[j].item()
                s = (S_floor + 1) if j < n_higher else S_floor
                s = max(s, 1)

                if s >= V:
                    result[k, pos] = result[k, pos].relu() + 1e-6
                else:
                    _, topk_idx = result[k, pos].topk(s)
                    new_vals = torch.zeros_like(result[k, pos])
                    new_vals[topk_idx] = result[k, pos, topk_idx].relu() + 1e-6
                    result[k, pos] = new_vals
                result[k, pos] /= result[k, pos].sum()

        return result