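# NOTE(review): the lines below look like code-hosting page chrome captured
# together with the file; they are kept as comments so the module can parse.
# Files
#
# 241 lines
# 8.9 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.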
"""v153: MC-GCG ILS with focused gradient sampling (topk=128).
v104 = 0.1367 (BEST). v104 uses topk_per_position=384 — for each position,
candidate tokens are sampled from the top-384 tokens ranked by negative gradient.
With 384 options per position, random selection is unfocused and includes many
mediocre tokens that the gradient barely favors.
v153 reduces topk_per_position from 384 to 128. This makes each candidate more
"gradient-aligned" — every candidate change uses a token from the top-128,
which should be higher quality on average.
Why this might help:
- With top-128, candidates are more concentrated on the truly best tokens.
Even with 768 candidates sampling from 128 tokens × 20 positions, there's
still ample diversity.
- For this specific problem (binary classifier forcing "safe"), the gradient
should have a clear signal — top-128 captures it while top-384 dilutes it.
- Zero FLOP overhead. The topk sort is negligible CPU work.
Risk: Too restrictive — tokens ranked 129-384 might include valuable
alternatives that top-128 misses. This would reduce diversity.
All other params identical to v104.
"""
import torch
from torch import Tensor
from claudini.base import TokenOptimizer
from claudini.tokens import sample_ids_from_grad
class V153Optimizer(TokenOptimizer):
    """MC-GCG ILS with focused gradient sampling (topk=128).

    The run is divided by FLOP budget (``max_flops``) into two phases:

    * Phase 1 lasts until ``PHASE1_FRAC`` of the budget has been spent; every
      step searches from the best suffix found so far.
    * Phase 2 is an iterated local search. Each cycle lasts
      ``CYCLE_BUDGET_FRAC`` of the budget and begins from a randomly
      perturbed copy of the best suffix; steps then search from the cycle's
      current suffix.

    Each step samples gradient-guided candidates, evaluates them, and also
    evaluates progressive merges of the ``MERGE_K`` lowest-loss candidates.
    When no ``max_flops`` is given, progress is reported as 0.0 and the
    optimizer therefore never leaves phase 1.
    """

    method_name = "claude_oss2_v153"
    # Fraction of the FLOP budget spent before switching to phase 2.
    PHASE1_FRAC = 0.10
    # Fraction of the FLOP budget allotted to a single phase-2 cycle.
    CYCLE_BUDGET_FRAC = 0.03
    # Upper bound on how many top candidates are progressively merged.
    MERGE_K = 7
    # Top-k value handed to ``sample_ids_from_grad``.
    TOPK_PER_POSITION = 128

    def __init__(self, model, tokenizer, optim_length=20, seed=None, **kwargs):
        """Build the optimizer and reset all search bookkeeping.

        Args:
            model: Forwarded to the base-class constructor.
            tokenizer: Forwarded to the base-class constructor.
            optim_length: Forwarded to the base-class constructor.
            seed: Forwarded to the base-class constructor.
            **kwargs: Accepted but not forwarded to the base class.
        """
        super().__init__(
            model,
            tokenizer,
            optim_length=optim_length,
            seed=seed,
            allow_non_ascii=True,
        )
        # Suffix token ids; ``setup`` stores them with a leading batch dim.
        self.current_ids: Tensor | None = None
        self.best_ids: Tensor | None = None
        self.best_loss: float = float("inf")
        # FLOP budget; assigned in ``run``.
        self.max_flops: float | None = None
        # Phase-2 (ILS) bookkeeping.
        self.cycle_idx: int = 0
        self._cycle_start_flops: float = 0.0
        self._in_phase2: bool = False

    def setup(self, prompt, target):
        """Prepare the prompt/target and reset the search to a fresh start."""
        self._prepare_prompt(prompt, target)
        start_ids = self._init_optim_ids().unsqueeze(0)
        self.current_ids = start_ids
        self.best_ids = start_ids.clone()
        self.best_loss = float("inf")
        self.cycle_idx = 0
        self._in_phase2 = False
        self._cycle_start_flops = 0.0

    def _get_progress(self) -> float:
        """Return the spent share of the FLOP budget in [0, 1]; 0.0 if no budget."""
        budget = self.max_flops
        if not budget or budget <= 0:
            return 0.0
        return min(1.0, self.flop_counter.total_flops / budget)

    def _get_cycle_progress(self) -> float:
        """Return the spent share of the current cycle's budget, capped at 1.0."""
        budget = self.max_flops
        if not budget:
            return 0.0
        spent_in_cycle = self.flop_counter.total_flops - self._cycle_start_flops
        return min(1.0, spent_in_cycle / (budget * self.CYCLE_BUDGET_FRAC))

    def _get_perturb_positions(self) -> int:
        """Return how many positions a restart randomizes (5, then 3, then 1)."""
        progress = self._get_progress()
        if progress < 0.50:
            return 5
        if progress < 0.75:
            return 3
        return 1

    def _get_search_width(self) -> int:
        """Return the number of candidates to sample (768, then 512, then 384)."""
        progress = self._get_progress()
        for upper_bound, width in ((0.40, 768), (0.75, 512)):
            if progress < upper_bound:
                return width
        return 384

    def _perturb_best(self, num_positions: int) -> Tensor:
        """Return a copy of ``best_ids`` with random tokens at random positions.

        Args:
            num_positions: Number of positions to overwrite; clamped to the
                suffix length.

        Returns:
            A new tensor shaped like ``best_ids``; ``best_ids`` is not modified.
        """
        mutated = self.best_ids.clone()
        device = mutated.device
        seq_len = mutated.shape[1]
        n_swaps = min(num_positions, seq_len)
        chosen = torch.randperm(seq_len, device=device)[:n_swaps]
        # One draw per position, each uniform over the whole embedding table.
        for idx in chosen:
            mutated[0, idx] = torch.randint(
                0,
                self.embedding_layer.num_embeddings,
                (1,),
                device=device,
            )
        return mutated

    def _progressive_merge(self, current_ids: Tensor, top_k_candidates: Tensor) -> Tensor:
        """Stack cumulative merges of the candidates' edits onto ``current_ids``.

        Row ``i`` of the result is ``current_ids`` with the changes of
        candidates ``0..i`` applied in order; where several candidates change
        the same position, the later one wins.

        Args:
            current_ids: Token ids the candidates were derived from.
            top_k_candidates: Candidate id rows, iterated along dim 0.

        Returns:
            A tensor with one merged row per candidate.
        """
        running = current_ids.clone()
        snapshots = []
        for candidate in top_k_candidates:
            differs = candidate != current_ids
            running = torch.where(differs, candidate, running)
            snapshots.append(running.clone())
        return torch.stack(snapshots, dim=0)

    def step(self, step_num):
        """Handle phase/cycle transitions, then run one search step."""
        if not self._in_phase2:
            if self._get_progress() >= self.PHASE1_FRAC:
                # Phase-1 budget exhausted: switch to ILS and open cycle 1.
                self._in_phase2 = True
                self._start_ils_cycle()
        if self._in_phase2 and self._get_cycle_progress() >= 1.0:
            self._start_ils_cycle()
        return self._gcg_step(step_num)

    def _start_ils_cycle(self):
        """Open a new ILS cycle from a perturbed copy of the best suffix."""
        self.cycle_idx += 1
        self.current_ids = self._perturb_best(self._get_perturb_positions())
        self._cycle_start_flops = self.flop_counter.total_flops

    def _gcg_step(self, step_num):
        """Run one gradient-guided sampling step with progressive merging.

        Args:
            step_num: Step index supplied by the caller; not used here.

        Returns:
            ``(best_loss, None, optim_str)``, where ``optim_str`` is the
            decoded best suffix found so far.
        """
        # Phase 2 continues from the cycle's current suffix; phase 1 always
        # searches from the best suffix.
        start_ids = self.current_ids if self._in_phase2 else self.best_ids
        token_grad = self._compute_token_gradient(start_ids)
        self.flop_counter.count_forward_backward(self.total_seq_len)
        width = self._get_search_width()

        with torch.no_grad():
            flat_start = start_ids.squeeze(0)
            candidates = sample_ids_from_grad(
                flat_start,
                token_grad.squeeze(0),
                width,
                self.TOPK_PER_POSITION,
                1,  # presumably replacements per candidate; confirm in claudini.tokens
                not_allowed_ids=self.not_allowed_ids,
            )
            n_candidates = candidates.shape[0]
            cand_losses = self._eval_candidates(candidates)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=n_candidates)

            n_merge = min(self.MERGE_K, n_candidates)
            order = cand_losses.argsort()
            merged = self._progressive_merge(flat_start, candidates[order[:n_merge]])
            merged_losses = self._eval_candidates(merged)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=n_merge)

            top_idx = order[0]
            single_loss = float(cand_losses[top_idx].item())
            merged_pick = merged_losses.argmin()
            merged_loss = float(merged_losses[merged_pick].item())

            # Row 0 of ``merged`` is the top-ranked candidate itself, so ties
            # go to the merged branch (merge_level >= 1).
            if merged_loss <= single_loss:
                step_loss = merged_loss
                merge_level = int(merged_pick.item()) + 1
                self.current_ids = merged[merged_pick].unsqueeze(0)
            else:
                step_loss = single_loss
                merge_level = 0
                self.current_ids = candidates[top_idx].unsqueeze(0)

            if step_loss < self.best_loss:
                self.best_loss = step_loss
                self.best_ids = self.current_ids.clone()

        perturb_count = self._get_perturb_positions() if self._in_phase2 else 0
        for key, value in (
            ("cycle", self.cycle_idx),
            ("perturb_p", perturb_count),
            ("merge_lvl", merge_level),
            ("sw", width),
        ):
            self.log(key, value, prog_bar=True)

        optim_str = self.tokenizer.batch_decode(self.best_ids)[0]
        self._step_ids = self.best_ids.squeeze(0)
        return self.best_loss, None, optim_str

    def _compute_token_gradient(self, optim_ids: Tensor) -> Tensor:
        """Return the target-loss gradient w.r.t. the one-hot suffix encoding.

        Args:
            optim_ids: Suffix token ids; callers in this class pass them with
                a leading batch dimension of 1.

        Returns:
            Gradient of the cross-entropy loss over the target tokens with
            respect to the one-hot encoding of ``optim_ids``.
        """
        emb = self.embedding_layer
        one_hot = torch.nn.functional.one_hot(
            optim_ids,
            num_classes=emb.num_embeddings,
        )
        one_hot = one_hot.to(self.model.device, self.model.dtype)
        one_hot.requires_grad_(True)

        # Embed through the one-hot matrix so the loss is differentiable
        # with respect to the token choice.
        suffix_embeds = one_hot @ emb.weight
        full_embeds = torch.cat(
            [self.before_embeds, suffix_embeds, self.after_embeds, self.target_embeds],
            dim=1,
        )
        logits = self.model(inputs_embeds=full_embeds).logits

        # Logits at position t predict token t + 1, hence the one-step offset
        # before the first target position.
        target_len = self.target_ids.shape[1]
        first = full_embeds.shape[1] - target_len - 1
        target_logits = logits[..., first : first + target_len, :].contiguous()
        loss = torch.nn.functional.cross_entropy(
            target_logits.view(-1, target_logits.size(-1)),
            self.target_ids.view(-1),
        )
        (grad,) = torch.autograd.grad(outputs=[loss], inputs=[one_hot])
        return grad

    def _eval_candidates(self, sampled_ids: Tensor) -> Tensor:
        """Return ``batched_loss`` for each candidate suffix in ``sampled_ids``.

        Args:
            sampled_ids: Candidate suffix ids, one candidate per row.
        """
        batch = sampled_ids.shape[0]
        pieces = [
            self.before_embeds.expand(batch, -1, -1),
            self.embedding_layer(sampled_ids),
            self.after_embeds.expand(batch, -1, -1),
            self.target_embeds.expand(batch, -1, -1),
        ]
        return self.batched_loss(torch.cat(pieces, dim=1))

    def run(self, prompt, target, num_steps, max_flops=None, max_time=None, **kwargs):
        """Record the FLOP budget for the schedules, then defer to the base ``run``."""
        self.max_flops = max_flops
        return super().run(prompt, target, num_steps, max_flops=max_flops, max_time=max_time, **kwargs)