mirror of
https://github.com/romovpa/claudini.git
synced 2026-06-07 20:33:54 +02:00
241 lines
8.9 KiB
Python
241 lines
8.9 KiB
Python
"""v153: MC-GCG ILS with focused gradient sampling (topk=128).

v104 = 0.1367 (BEST). v104 uses topk_per_position=384 — for each position,
candidate tokens are sampled from the top-384 tokens ranked by negative gradient.
With 384 options per position, random selection is unfocused and includes many
mediocre tokens that the gradient barely favors.

v153 reduces topk_per_position from 384 to 128. This makes each candidate more
"gradient-aligned" — every candidate change uses a token from the top-128,
which should be higher quality on average.

Why this might help:
- With top-128, candidates are more concentrated on the truly best tokens.
  Even with 768 candidates sampling from 128 tokens × 20 positions, there's
  still ample diversity.
- For this specific problem (binary classifier forcing "safe"), the gradient
  should have a clear signal — top-128 captures it while top-384 dilutes it.
- Zero FLOP overhead. The topk sort is negligible CPU work.

Risk: Too restrictive — tokens ranked 129-384 might include valuable
alternatives that top-128 misses. This would reduce diversity.

All other params identical to v104.
"""
|
||
|
||
import torch
|
||
from torch import Tensor
|
||
|
||
from claudini.base import TokenOptimizer
|
||
from claudini.tokens import sample_ids_from_grad
|
||
|
||
|
||
class V153Optimizer(TokenOptimizer):
    """MC-GCG ILS with focused gradient sampling (topk=128).

    The run has two phases, both driven by the fraction of ``max_flops``
    consumed so far (with no ``max_flops`` the progress is always 0.0, so
    the optimizer stays in phase 1):

    * Phase 1 (progress below ``PHASE1_FRAC``): every step searches around
      ``best_ids``.
    * Phase 2: iterated local search (ILS). Each cycle starts from a
      randomly perturbed copy of ``best_ids`` and searches around
      ``current_ids`` until the cycle has used ``CYCLE_BUDGET_FRAC`` of
      ``max_flops``; then a new cycle starts.

    Every step samples gradient-guided candidates, scores them, and then
    also scores progressive merges of the ``MERGE_K`` best candidates.
    """

    method_name = "claude_oss2_v153"

    # Phase 2 (ILS) begins once this fraction of max_flops has been used.
    PHASE1_FRAC = 0.10
    # Each ILS cycle is allotted this fraction of max_flops.
    CYCLE_BUDGET_FRAC = 0.03
    # Upper bound on how many of the best candidates enter the merge.
    MERGE_K = 7
    # Per-position top-k handed to sample_ids_from_grad.
    TOPK_PER_POSITION = 128

    def __init__(self, model, tokenizer, optim_length=20, seed=None, **kwargs):
        """Create the optimizer and initialise its search state.

        Args:
            model: Forwarded unchanged to ``TokenOptimizer.__init__``.
            tokenizer: Forwarded unchanged to ``TokenOptimizer.__init__``.
            optim_length: Forwarded to the base class (default 20).
            seed: Forwarded to the base class.
            **kwargs: Accepted but not forwarded to the base class.
        """
        # NOTE(review): extra keyword arguments are dropped here instead of
        # being passed to the base class — confirm this is intentional.
        super().__init__(
            model,
            tokenizer,
            optim_length=optim_length,
            seed=seed,
            allow_non_ascii=True,
        )
        # Token ids the next search step starts from (phase 2 only).
        self.current_ids: Tensor | None = None
        # Lowest-loss token ids seen so far and their loss.
        self.best_ids: Tensor | None = None
        self.best_loss: float = float("inf")
        # Total FLOP budget; set by run().
        self.max_flops: float | None = None
        # Number of ILS cycles started so far.
        self.cycle_idx: int = 0
        # FLOP counter reading taken when the current ILS cycle began.
        self._cycle_start_flops: float = 0.0
        self._in_phase2: bool = False

    def setup(self, prompt, target):
        """Prepare the prompt/target and reset all search state."""
        self._prepare_prompt(prompt, target)
        # Add a leading batch dimension of size 1.
        init_ids = self._init_optim_ids().unsqueeze(0)
        self.current_ids = init_ids
        self.best_ids = init_ids.clone()
        self.best_loss = float("inf")
        self._cycle_start_flops = 0.0
        self._in_phase2 = False
        self.cycle_idx = 0

    def _get_progress(self) -> float:
        """Return the fraction of ``max_flops`` used so far, capped at 1.0.

        Returns 0.0 when ``max_flops`` is unset or not positive.
        """
        if not self.max_flops or self.max_flops <= 0:
            return 0.0
        return min(1.0, self.flop_counter.total_flops / self.max_flops)

    def _get_cycle_progress(self) -> float:
        """Return the fraction of the current ILS cycle's budget used.

        The budget is ``max_flops * CYCLE_BUDGET_FRAC`` and usage is counted
        from ``_cycle_start_flops``. The result is capped at 1.0. Returns
        0.0 when no positive budget is available, using the same
        ``max_flops`` guard as ``_get_progress``.
        """
        if not self.max_flops or self.max_flops <= 0:
            return 0.0
        cycle_budget = self.max_flops * self.CYCLE_BUDGET_FRAC
        if cycle_budget <= 0:
            # A zero or negative fraction would otherwise divide by zero or
            # produce a meaningless negative ratio.
            return 0.0
        elapsed = self.flop_counter.total_flops - self._cycle_start_flops
        return min(1.0, elapsed / cycle_budget)

    def _get_perturb_positions(self) -> int:
        """Return how many positions an ILS restart perturbs.

        The count shrinks as the run progresses: 5 below 50% progress,
        3 below 75%, and 1 afterwards.
        """
        progress = self._get_progress()
        if progress < 0.50:
            return 5
        elif progress < 0.75:
            return 3
        else:
            return 1

    def _get_search_width(self) -> int:
        """Return the number of candidates to sample this step.

        768 below 40% progress, 512 below 75%, and 384 afterwards.
        """
        progress = self._get_progress()
        if progress < 0.40:
            return 768
        elif progress < 0.75:
            return 512
        else:
            return 384

    def _perturb_best(self, num_positions: int) -> Tensor:
        """Return a copy of ``best_ids`` with random positions re-drawn.

        Args:
            num_positions: Number of distinct positions to overwrite;
                clamped to the sequence length.

        Returns:
            A new tensor shaped like ``best_ids``; ``best_ids`` itself is
            not modified.
        """
        perturbed = self.best_ids.clone()
        L = perturbed.shape[1]
        num_positions = min(num_positions, L)
        # A random permutation prefix gives distinct positions.
        positions = torch.randperm(L, device=perturbed.device)[:num_positions]
        for pos in positions:
            # NOTE(review): replacement ids are drawn uniformly from the
            # whole embedding table; self.not_allowed_ids is not consulted
            # here, unlike in the gradient sampling step — confirm intended.
            random_token = torch.randint(
                0,
                self.embedding_layer.num_embeddings,
                (1,),
                device=perturbed.device,
            )
            perturbed[0, pos] = random_token
        return perturbed

    def _progressive_merge(self, current_ids: Tensor, top_k_candidates: Tensor) -> Tensor:
        """Build cumulative merges of the top candidates.

        Row ``i`` of the result is ``current_ids`` with the changes of
        candidates ``0..i`` applied in order; where several candidates
        change the same position, the later candidate wins.

        Args:
            current_ids: Token ids the candidates were derived from.
            top_k_candidates: Candidates stacked along dim 0, best first.

        Returns:
            The merged sequences stacked along dim 0, one per candidate.
        """
        merged = current_ids.clone()
        merged_list = []
        for candidate in top_k_candidates:
            # Positions where this candidate differs from the starting ids.
            changed_mask = candidate != current_ids
            merged = torch.where(changed_mask, candidate, merged)
            merged_list.append(merged.clone())
        return torch.stack(merged_list, dim=0)

    def step(self, step_num):
        """Run one optimisation step, handling phase and cycle transitions.

        Switches to phase 2 once progress reaches ``PHASE1_FRAC`` and
        starts a new ILS cycle whenever the current one has used up its
        budget. Returns whatever ``_gcg_step`` returns.
        """
        progress = self._get_progress()
        if not self._in_phase2 and progress >= self.PHASE1_FRAC:
            self._in_phase2 = True
            self._start_ils_cycle()
        if self._in_phase2 and self._get_cycle_progress() >= 1.0:
            self._start_ils_cycle()
        return self._gcg_step(step_num)

    def _start_ils_cycle(self):
        """Start a new ILS cycle from a perturbed copy of ``best_ids``."""
        self.cycle_idx += 1
        p = self._get_perturb_positions()
        perturbed = self._perturb_best(p)
        self.current_ids = perturbed
        # The cycle budget is measured from this counter reading.
        self._cycle_start_flops = self.flop_counter.total_flops

    def _gcg_step(self, step_num):
        """Sample, score and merge candidates, then update the search state.

        Returns:
            A tuple ``(best_loss, None, optim_str)`` where ``optim_str`` is
            the decoded ``best_ids``.
        """
        # Phase 1 searches around the incumbent best; phase 2 searches
        # around the current point of the running ILS cycle.
        search_ids = self.current_ids if self._in_phase2 else self.best_ids

        grad = self._compute_token_gradient(search_ids)
        self.flop_counter.count_forward_backward(self.total_seq_len)

        sw = self._get_search_width()

        # Sampling and candidate scoring need no autograd graph.
        with torch.no_grad():
            sampled_ids = sample_ids_from_grad(
                search_ids.squeeze(0),
                grad.squeeze(0),
                sw,
                self.TOPK_PER_POSITION,
                # presumably the number of tokens replaced per candidate —
                # confirm against sample_ids_from_grad's signature
                1,
                not_allowed_ids=self.not_allowed_ids,
            )
            # Use the returned batch size rather than sw for accounting.
            actual_B = sampled_ids.shape[0]

            batch_losses = self._eval_candidates(sampled_ids)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=actual_B)

            k = min(self.MERGE_K, actual_B)
            sorted_indices = batch_losses.argsort()
            top_k_candidates = sampled_ids[sorted_indices[:k]]

            merged_candidates = self._progressive_merge(search_ids.squeeze(0), top_k_candidates)
            merged_losses = self._eval_candidates(merged_candidates)
            self.flop_counter.count_forward(self.total_seq_len, batch_size=k)

            single_best_loss = float(batch_losses[sorted_indices[0]].item())
            merged_best_idx = merged_losses.argmin()
            merged_best_loss = float(merged_losses[merged_best_idx].item())

            # Ties go to the merged sequence.
            if merged_best_loss <= single_best_loss:
                batch_best_loss = merged_best_loss
                self.current_ids = merged_candidates[merged_best_idx].unsqueeze(0)
                # 1-based count of candidates folded into the winning merge.
                merge_level = int(merged_best_idx.item()) + 1
            else:
                batch_best_loss = single_best_loss
                self.current_ids = sampled_ids[sorted_indices[0]].unsqueeze(0)
                merge_level = 0

            # current_ids always moves to this step's winner; best_ids only
            # moves on a strict improvement.
            if batch_best_loss < self.best_loss:
                self.best_loss = batch_best_loss
                self.best_ids = self.current_ids.clone()

        p = self._get_perturb_positions() if self._in_phase2 else 0
        self.log("cycle", self.cycle_idx, prog_bar=True)
        self.log("perturb_p", p, prog_bar=True)
        self.log("merge_lvl", merge_level, prog_bar=True)
        self.log("sw", sw, prog_bar=True)

        optim_str = self.tokenizer.batch_decode(self.best_ids)[0]
        # presumably read by the base class; not read anywhere in this class
        self._step_ids = self.best_ids.squeeze(0)
        return self.best_loss, None, optim_str

    def _compute_token_gradient(self, optim_ids: Tensor) -> Tensor:
        """Return the target-loss gradient w.r.t. one-hot token choices.

        The ids are one-hot encoded and multiplied by the embedding matrix
        so the loss is differentiable with respect to the token choice.
        The loss is the cross-entropy of the logits that predict the
        target tokens.

        Args:
            optim_ids: Token ids of the optimised segment.

        Returns:
            The gradient, shaped like the one-hot encoding of ``optim_ids``.
        """
        embedding_layer = self.embedding_layer
        optim_ids_onehot = torch.nn.functional.one_hot(
            optim_ids,
            num_classes=embedding_layer.num_embeddings,
        ).to(self.model.device, self.model.dtype)
        optim_ids_onehot.requires_grad_(True)

        optim_embeds = optim_ids_onehot @ embedding_layer.weight
        input_embeds = torch.cat(
            [self.before_embeds, optim_embeds, self.after_embeds, self.target_embeds],
            dim=1,
        )
        output = self.model(inputs_embeds=input_embeds)
        logits = output.logits
        # Index of the first target position in the concatenated input.
        shift = input_embeds.shape[1] - self.target_ids.shape[1]
        target_len = self.target_ids.shape[1]
        # Logits at position t predict token t + 1, hence the -1 offset.
        shift_logits = logits[..., shift - 1 : shift - 1 + target_len, :].contiguous()

        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            self.target_ids.view(-1),
        )

        grad = torch.autograd.grad(outputs=[loss], inputs=[optim_ids_onehot])[0]
        return grad

    def _eval_candidates(self, sampled_ids: Tensor) -> Tensor:
        """Return ``batched_loss`` for a batch of candidate sequences.

        Each candidate is embedded and placed between the fixed
        before/after embeddings, followed by the target embeddings.

        Args:
            sampled_ids: Candidate token ids, one candidate per row.
        """
        actual_B = sampled_ids.shape[0]
        input_embeds = torch.cat(
            [
                # expand() broadcasts the fixed segments over the batch.
                self.before_embeds.expand(actual_B, -1, -1),
                self.embedding_layer(sampled_ids),
                self.after_embeds.expand(actual_B, -1, -1),
                self.target_embeds.expand(actual_B, -1, -1),
            ],
            dim=1,
        )
        return self.batched_loss(input_embeds)

    def run(self, prompt, target, num_steps, max_flops=None, max_time=None, **kwargs):
        """Record ``max_flops`` for progress tracking, then run the base loop.

        All arguments are forwarded unchanged to ``TokenOptimizer.run``,
        whose result is returned.
        """
        # The progress helpers need the budget before the first step.
        self.max_flops = max_flops
        return super().run(prompt, target, num_steps, max_flops=max_flops, max_time=max_time, **kwargs)
|