diff --git a/claudini/methods/claude_unrolled/README.md b/claudini/methods/claude_unrolled/README.md new file mode 100644 index 0000000..4b29f30 --- /dev/null +++ b/claudini/methods/claude_unrolled/README.md @@ -0,0 +1,18 @@ +# Unrolled Methods + +Unrolled versions are **standalone, cleaned-up rewrites** of existing methods. The originals often inherit from each other in long chains and accumulate dead code. An unrolled version flattens the inheritance, keeps only the logic that matters, and makes the method readable as a single file. + +## Why unroll + +- **Auditability**: one file, no inheritance maze, every line does something. +- **Reproducibility**: self-contained code is easier to port or share. +- **Documentation**: the module docstring becomes the ground truth for what the method does. + +## How to create one + +1. **Read the original** method chain top-to-bottom. Identify which pieces of inherited logic are actually used. +2. **Write a flat optimizer** that subclasses `TokenOptimizer` directly. Inline all logic into `setup()`, `step()`, and a few private helpers. No intermediate base classes. +3. **Simplify**: remove dead branches, unused hyperparameters, and compatibility shims. Three clear lines beat one clever abstraction. +4. **Verify equivalence**: run a small experiment with the same seed on both the original and unrolled version. The results must match bit-for-bit. Run a few configs if the method has branching logic (e.g. different `n_replace` phases). +5. **Write the module docstring**: one-line summary of what the method combines, a numbered list of component ideas with short explanations and paper references, and a pseudocode block showing the full optimization loop. +6. **Name it** `_unrolled`, matching the original's `method_name`. Place it in `claude_unrolled//` with `optimizer.py` and `__init__.py`. diff --git a/claudini/methods/claude_unrolled/__init__.py b/claudini/methods/claude_unrolled/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/claudini/methods/claude_unrolled/__init__.py @@ -0,0 +1 @@ + diff --git a/claudini/methods/claude_unrolled/claude_oss_v53/__init__.py b/claudini/methods/claude_unrolled/claude_oss_v53/__init__.py new file mode 100644 index 0000000..12e5675 --- /dev/null +++ b/claudini/methods/claude_unrolled/claude_oss_v53/__init__.py @@ -0,0 +1,3 @@ +from .optimizer import ClaudeOssV53UnrolledOptimizer + +__all__ = ["ClaudeOssV53UnrolledOptimizer"] diff --git a/claudini/methods/claude_unrolled/claude_oss_v53/optimizer.py b/claudini/methods/claude_unrolled/claude_oss_v53/optimizer.py new file mode 100644 index 0000000..d31e608 --- /dev/null +++ b/claudini/methods/claude_unrolled/claude_oss_v53/optimizer.py @@ -0,0 +1,301 @@ +"""Claude OSS v53 (unrolled): MAC + TAO DPTO with Coarse-to-Fine Replacement. + +Combines three ideas: + +1. **DPTO candidate selection** (Direction-Priority Token Optimization). + For each position, filters vocabulary tokens by cosine similarity + between the negative gradient direction and candidate displacement + vectors, then samples replacements from the filtered set via + temperature-scaled softmax over projected step magnitudes. + Reference: "TAO-Attack: Toward Advanced Optimization-Based Jailbreak + Attacks for Large Language Models" (Xu et al., ICLR 2026, + arXiv:2603.03081). + +2. **Momentum-smoothed embedding gradients** (MAC). + An exponential moving average (EMA) of the embedding-space gradient + replaces the raw per-step gradient, reducing noise in the search + direction. The momentum gradient is fed into DPTO for candidate + selection. + Reference: "Boosting Jailbreak Attack with Momentum" (Zhang & Wei, + ICASSP 2025). + +3. **Coarse-to-fine replacement schedule** (novel, claude_oss_v53). + For the first 80% of optimization steps, each candidate replaces + n_replace=2 positions (coarse exploration). For the final 20%, + the method switches to n_replace=1 (fine-grained refinement of + individual positions). + +Pseudocode:: + + x ~ random tokens # [L] + m = None # momentum buffer + for each step: + # --- embedding gradient --- + embed = one_hot(x) @ W_embed # [1, L, D] + loss = CE(model([prefix | embed | suffix | target]), target) + g = d(loss)/d(embed) # [1, L, D] + # --- momentum update (MAC) --- + m = mu * m + (1 - mu) * g # EMA + # --- DPTO candidate selection (TAO) --- + for each position: + cos = cosine(g_pos, embed_pos - W_embed) + top_indices = topk(cos, topk_per_position) + dot_scores = g_pos . (embed_pos - W_embed[top_indices]) + probs = softmax(dot_scores / temperature) + sample B candidates, each replacing n_replace positions + # --- coarse-to-fine schedule --- + n_replace = 2 if step < 0.8 * total_steps else 1 + # --- evaluate and keep best --- + losses = [CE(model([prefix | cand | suffix | target])) for cand in candidates] + x = candidates[argmin(losses)] +""" + +import logging + +import torch +from torch import Tensor +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from claudini.base import TokenOptimizer + +logger = logging.getLogger("claudini") + + +class ClaudeOssV53UnrolledOptimizer(TokenOptimizer): + """MAC + TAO DPTO with coarse-to-fine replacement. See module docstring.""" + + method_name = "claude_oss_v53_unrolled" + is_soft = False + + # -- Hyperparameter defaults ------------------------------------------------ + DEFAULT_NUM_CANDIDATES = 80 + DEFAULT_TOPK_PER_POSITION = 300 + DEFAULT_TEMPERATURE = 0.4 + DEFAULT_N_REPLACE = 2 + DEFAULT_MOMENTUM = 0.908 + DEFAULT_SWITCH_FRACTION = 0.8 + + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + optim_length: int = 20, + num_candidates: int = DEFAULT_NUM_CANDIDATES, + topk_per_position: int = DEFAULT_TOPK_PER_POSITION, + temperature: float = DEFAULT_TEMPERATURE, + n_replace: int = DEFAULT_N_REPLACE, + momentum: float = DEFAULT_MOMENTUM, + switch_fraction: float = DEFAULT_SWITCH_FRACTION, + seed: int | None = None, + allow_non_ascii: bool = True, + ): + super().__init__(model, tokenizer, optim_length, seed, allow_non_ascii) + + # Hyperparameters + self.num_candidates = num_candidates + self.topk_per_position = topk_per_position + self.temperature = temperature + self.n_replace = n_replace + self.momentum = momentum + self.switch_fraction = switch_fraction + + # State (populated in setup) + self.current_ids: Tensor | None = None + self.momentum_grad: Tensor | None = None + self._estimated_steps = 131 + + # -- Setup ------------------------------------------------------------------ + + def setup(self, prompt: str, target: str) -> None: + self._prepare_prompt(prompt, target) + self.current_ids = self._init_optim_ids().unsqueeze(0) + self.momentum_grad = None + logger.info( + "Claude OSS v53 (unrolled): B=%d, topk=%d, temp=%.2f, n_replace=%d->1 at %.0f%%, momentum=%.3f", + self.num_candidates, + self.topk_per_position, + self.temperature, + self.n_replace, + self.switch_fraction * 100, + self.momentum, + ) + + # -- Step ------------------------------------------------------------------- + + def step(self, step_num: int) -> tuple[float, float | None, str]: + # Coarse-to-fine: switch n_replace from 2 to 1 at 80% of budget + switch_step = int(self._estimated_steps * self.switch_fraction) + current_n_replace = 1 if step_num >= switch_step else self.n_replace + + self.log("temperature", self.temperature, prog_bar=True) + self.log("n_replace", float(current_n_replace), prog_bar=True) + + # 1. Compute embedding-space gradient (one fwd+bwd) + grad, optim_embeds = self._compute_embed_gradient(self.current_ids) + self.flop_counter.count_forward_backward(self.total_seq_len) + + with torch.no_grad(): + # 2. Update momentum on embedding gradient (MAC) + if self.momentum_grad is None: + self.momentum_grad = grad.clone() + else: + self.momentum_grad = self.momentum * self.momentum_grad + (1 - self.momentum) * grad + + # 3. DPTO candidate selection using momentum gradient (TAO) + sampled_ids = self._dpto_sample( + self.current_ids.squeeze(0), + optim_embeds.squeeze(0), + self.momentum_grad.squeeze(0), + current_n_replace, + ) + actual_B = sampled_ids.shape[0] + + # 4. Evaluate candidates + batch_losses = self._eval_candidates(sampled_ids) + self.flop_counter.count_forward(self.total_seq_len, batch_size=actual_B) + + # 5. Keep best + best_idx = batch_losses.argmin() + best_loss = float(batch_losses[best_idx].item()) + self.current_ids = sampled_ids[best_idx].unsqueeze(0) + + optim_str = self.tokenizer.batch_decode(self.current_ids)[0] + self._step_ids = self.current_ids.squeeze(0) + return best_loss, None, optim_str + + # -- Embedding gradient ----------------------------------------------------- + + def _compute_embed_gradient(self, optim_ids: Tensor) -> tuple[Tensor, Tensor]: + """Compute gradient of CE loss w.r.t. token embeddings. + + Returns: + grad: [1, L, D] gradient in embedding space + optim_embeds: [1, L, D] current token embeddings (detached) + """ + embedding_layer = self.embedding_layer + + optim_ids_onehot = torch.nn.functional.one_hot( + optim_ids, + num_classes=embedding_layer.num_embeddings, + ).to(self.model.device, self.model.dtype) + + optim_embeds = (optim_ids_onehot @ embedding_layer.weight).detach().clone() + optim_embeds.requires_grad_() + + input_embeds = torch.cat( + [self.before_embeds, optim_embeds, self.after_embeds, self.target_embeds], + dim=1, + ) + output = self.model(inputs_embeds=input_embeds) + + logits = output.logits + shift = input_embeds.shape[1] - self.target_ids.shape[1] + target_len = self.target_ids.shape[1] + shift_logits = logits[..., shift - 1 : shift - 1 + target_len, :].contiguous() + + loss = torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + self.target_ids.view(-1), + ) + + grad = torch.autograd.grad(outputs=[loss], inputs=[optim_embeds])[0] + return grad, optim_embeds.detach() + + # -- DPTO candidate sampling ------------------------------------------------ + + def _dpto_sample( + self, + control_toks: Tensor, + optim_embeds: Tensor, + grad: Tensor, + n_replace: int, + ) -> Tensor: + """Direction-Priority Token Optimization sampling using momentum gradient. + + Args: + control_toks: [L] current suffix token ids + optim_embeds: [L, D] current token embeddings + grad: [L, D] momentum gradient in embedding space + n_replace: number of positions to replace per candidate + + Returns: + new_ids: [B, L] candidate sequences + """ + eps = 1e-12 + embed_weights = self.embedding_layer.weight.detach() # [V, D] + L, D = optim_embeds.shape + device = grad.device + + # Step 1: Cosine similarity per position + grad_norm = grad / (grad.norm(dim=-1, keepdim=True) + eps) + topk = min(self.topk_per_position, embed_weights.shape[0]) + top_indices = torch.empty(L, topk, device=device, dtype=torch.long) + + for pos in range(L): + dir_pos = optim_embeds[pos] - embed_weights # [V, D] + dir_norm_pos = dir_pos / (dir_pos.norm(dim=-1, keepdim=True) + eps) + cos_pos = grad_norm[pos] @ dir_norm_pos.T # [V] + + # Mask forbidden tokens + if self.not_allowed_ids is not None: + cos_pos[self.not_allowed_ids.to(device)] = -float("inf") + cos_pos[control_toks[pos]] = -float("inf") + + _, top_indices[pos] = cos_pos.topk(topk) + + # Step 2: Projected step within filtered set + candidate_embeds = embed_weights[top_indices] # [L, k, D] + candidate_dirs = optim_embeds.unsqueeze(1) - candidate_embeds # [L, k, D] + dot_scores = torch.einsum("ld,lkd->lk", grad, candidate_dirs) # [L, k] + + # Step 3: Temperature-scaled softmax sampling + probs = torch.softmax(dot_scores / max(self.temperature, eps), dim=1) + + # Sample candidates + B = self.num_candidates + original_ids = control_toks.repeat(B, 1) # [B, L] + + if n_replace == 1: + samples_per_pos = B // L + remainder = B % L + all_positions = [] + all_tokens = [] + + for pos in range(L): + n = samples_per_pos + (1 if pos < remainder else 0) + if n > 0: + token_indices = torch.multinomial(probs[pos], n, replacement=True) + token_ids = top_indices[pos][token_indices] + all_positions.extend([pos] * n) + all_tokens.append(token_ids) + + positions = torch.tensor(all_positions, device=device, dtype=torch.long) + tokens = torch.cat(all_tokens, dim=0) + original_ids[torch.arange(B, device=device), positions] = tokens + else: + for b in range(B): + pos_perm = torch.randperm(L, device=device)[:n_replace] + for pos in pos_perm: + token_idx = torch.multinomial(probs[pos], 1).item() + original_ids[b, pos] = top_indices[pos, token_idx] + + return original_ids + + # -- Candidate evaluation --------------------------------------------------- + + def _eval_candidates(self, sampled_ids: Tensor) -> Tensor: + """Evaluate loss on candidate sequences.""" + actual_B = sampled_ids.shape[0] + embedding_layer = self.embedding_layer + + input_embeds = torch.cat( + [ + self.before_embeds.expand(actual_B, -1, -1), + embedding_layer(sampled_ids), + self.after_embeds.expand(actual_B, -1, -1), + self.target_embeds.expand(actual_B, -1, -1), + ], + dim=1, + ) + + return self.batched_loss(input_embeds) diff --git a/claudini/methods/claude_unrolled/claude_v63/__init__.py b/claudini/methods/claude_unrolled/claude_v63/__init__.py new file mode 100644 index 0000000..4450db8 --- /dev/null +++ b/claudini/methods/claude_unrolled/claude_v63/__init__.py @@ -0,0 +1,3 @@ +from .optimizer import ClaudeV63UnrolledOptimizer + +__all__ = ["ClaudeV63UnrolledOptimizer"] diff --git a/claudini/methods/claude_unrolled/claude_v63/optimizer.py b/claudini/methods/claude_unrolled/claude_v63/optimizer.py new file mode 100644 index 0000000..1fc4a55 --- /dev/null +++ b/claudini/methods/claude_unrolled/claude_v63/optimizer.py @@ -0,0 +1,298 @@ +"""Claude v63 (unrolled): Decoupled ADC with Layer-wise Gradient Scaling. + +Combines three ideas: + +1. **ADC** (Adaptive Dense-to-sparse Constrained optimization). + Optimizes soft probability distributions z in [K, L, V] over the + vocabulary via SGD with heavy momentum. An adaptive sparsity + schedule gradually constrains each distribution from dense (full + vocabulary) to sparse (near one-hot) based on how many target tokens + the current restart mispredicts. + Reference: "Efficient LLM Jailbreak via Adaptive Dense-to-sparse + Constrained Optimization" (NeurIPS 2024, arXiv:2405.09113). + +2. **Decoupled K/lr** (claude_v19). + The original ADC averages the CE loss over the K restarts, coupling + the effective gradient magnitude to K. Here the loss is *summed* + over restarts so that the learning rate is independent of K. This + lets K control exploration breadth while lr controls step size. + +3. **LSGM — Layer-wise SGD with Gradual Momentum**. + Backward hooks on every LayerNorm module scale incoming gradients + by gamma < 1, amplifying the skip-connection gradient signal + relative to the residual branch. Originally proposed for GCG in + "Improved Generation of Adversarial Examples Against Safety-aligned + LLMs" (Li et al., NeurIPS 2024, arXiv:2405.20778); here applied + to ADC's continuous optimization with a milder gamma (0.85 vs the + paper's 0.5). + +Pseudocode:: + + z ~ softmax(Normal(0, I)) # [K, L, V] + register backward hooks: grad *= gamma on all LayerNorm modules + for each step: + # --- soft forward --- + soft_embeds = z @ W_embed # [K, L, D] + logits = model([prefix | soft_embeds | suffix | target]) # [K, S, V] + loss_k = CE(logits, target).mean(over tokens) # [K] + loss = sum(loss_k) # scalar, decoupled from K + loss.backward() + SGD.step() + # --- adaptive sparsity --- + wrong_k = count_mispredictions(logits, target) # [K] + ema_wrong += alpha * (wrong_k - ema_wrong) + S_k = clamp(2 ^ ema_wrong_k, max=V/2) + z_pre = z.clone() + z = sparsify(z, S_k) # keep top-S per position + # --- discrete evaluation --- + ids_k = argmax(z_pre, dim=-1) # [K, L] + losses_k = CE_discrete(ids_k) # [K] + track global best across all steps +""" + +import logging + +import torch +from torch import Tensor +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from claudini.base import TokenOptimizer + +logger = logging.getLogger("claudini") + +# LayerNorm module name patterns (covers Llama, Gemma, GPT-2, etc.) +_NORM_PATTERNS = ( + "input_layernorm", + "post_attention_layernorm", + "pre_feedforward_layernorm", + "post_feedforward_layernorm", + ".ln_1", + ".ln_2", +) + + +class ClaudeV63UnrolledOptimizer(TokenOptimizer): + """Decoupled ADC with LSGM gradient scaling. See module docstring.""" + + method_name = "claude_v63_unrolled" + is_soft = True + + # ── Hyperparameter defaults ────────────────────────────────────── + DEFAULT_LR = 10.0 + DEFAULT_MOMENTUM = 0.99 + DEFAULT_EMA_ALPHA = 0.01 + DEFAULT_NUM_STARTS = 6 + DEFAULT_LSGM_GAMMA = 0.85 + + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + optim_length: int = 20, + lr: float = DEFAULT_LR, + momentum: float = DEFAULT_MOMENTUM, + ema_alpha: float = DEFAULT_EMA_ALPHA, + num_starts: int = DEFAULT_NUM_STARTS, + lsgm_gamma: float = DEFAULT_LSGM_GAMMA, + seed: int | None = None, + allow_non_ascii: bool = False, + ): + super().__init__(model, tokenizer, optim_length, seed, allow_non_ascii) + + # Hyperparameters + self.lr = lr + self.momentum = momentum + self.ema_alpha = ema_alpha + self.num_starts = num_starts + self.lsgm_gamma = lsgm_gamma + + # State (populated in setup) + self.soft_opt: torch.nn.Parameter | None = None + self.optimizer: torch.optim.SGD | None = None + self.running_wrong: Tensor | None = None + self._global_best_loss: float = float("inf") + self._global_best_ids: Tensor | None = None + self._lsgm_handles: list = [] + + # ── Setup ──────────────────────────────────────────────────────── + + def setup(self, prompt: str, target: str) -> None: + self._prepare_prompt(prompt, target) + + K = self.num_starts + device = self.model.device + + # z ~ softmax(N(0, I)) for K restarts: [K, L, V] + z = torch.randn(K, self.optim_length, self.vocab_size, device=device) + if self.forbidden_mask is not None: + z[:, :, self.forbidden_mask] = -1e10 + z = z.softmax(dim=-1) + + self.soft_opt = torch.nn.Parameter(z) + self.optimizer = torch.optim.SGD([self.soft_opt], lr=self.lr, momentum=self.momentum) + self.running_wrong = None + self._global_best_loss = float("inf") + self._global_best_ids = None + + # LSGM: backward hooks that scale LayerNorm gradients by gamma + self._lsgm_handles = self._register_lsgm_hooks() + logger.info( + "Claude v63 (unrolled): LSGM(%d hooks, gamma=%.2f), K=%d, lr=%.1f", + len(self._lsgm_handles), + self.lsgm_gamma, + self.num_starts, + self.lr, + ) + + # ── Step ───────────────────────────────────────────────────────── + + def step(self, step_num: int) -> tuple[float, float | None, str]: + K = self.num_starts + self.optimizer.zero_grad() + + # 1. Soft embeddings: [K, L, V] @ [V, D] -> [K, L, D] + W = self.embedding_layer.weight.detach() + soft_embeds = torch.matmul( + self.soft_opt.to(torch.float32), + W.to(torch.float32), + ).to(self.model_dtype) + + # 2. Batched forward + input_embeds = torch.cat( + [ + self.before_embeds.expand(K, -1, -1), + soft_embeds, + self.after_embeds.expand(K, -1, -1), + self.target_embeds.expand(K, -1, -1), + ], + dim=1, + ) + logits = self.model(inputs_embeds=input_embeds).logits + shift = input_embeds.shape[1] - self.target_ids.shape[1] + target_len = self.target_ids.shape[1] + shift_logits = logits[..., shift - 1 : shift - 1 + target_len, :].contiguous() + + # 3. CE loss: mean over tokens, SUM over K (decoupled from lr) + target_expanded = self.target_ids.expand(K, -1) + loss_per_token = torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + target_expanded.reshape(-1), + reduction="none", + ) + loss_per_restart = loss_per_token.view(K, target_len).mean(dim=1) + soft_loss = loss_per_restart.sum() + soft_loss_val = float(soft_loss.item() / K) + + with torch.no_grad(): + preds = shift_logits.argmax(dim=-1) + wrong_counts = (preds != target_expanded).float().sum(dim=1) + + # 4. Backward + SGD update + soft_loss.backward() + self.optimizer.step() + self.flop_counter.count_forward_backward(self.total_seq_len, batch_size=K) + + with torch.no_grad(): + # 5. Adaptive sparsity: S_k = 2^(EMA of wrong count) + if self.running_wrong is None: + self.running_wrong = wrong_counts.clone() + else: + self.running_wrong += (wrong_counts - self.running_wrong) * self.ema_alpha + + sparsities = (2.0**self.running_wrong).clamp(max=self.vocab_size / 2) + + if self.forbidden_mask is not None: + self.soft_opt.data[:, :, self.forbidden_mask] = -1000.0 + + pre_sparse = self.soft_opt.data.clone() + + # 6. Sparsify: keep top-S per position per restart + sparse_z = self._make_sparse_batched(self.soft_opt.data, sparsities) + self.soft_opt.data.copy_(sparse_z) + + # 7. Discrete eval: argmax per restart, pick global best + all_ids = pre_sparse.argmax(dim=-1) + discrete_losses = self.compute_discrete_loss_batch(all_ids) + self.flop_counter.count_forward(self.total_seq_len, batch_size=K) + + best_k = discrete_losses.argmin().item() + step_best_loss = discrete_losses[best_k].item() + + if step_best_loss < self._global_best_loss: + self._global_best_loss = step_best_loss + self._global_best_ids = all_ids[best_k].clone() + + self._step_ids = self._global_best_ids + optim_str = self.tokenizer.decode(self._global_best_ids) + + return step_best_loss, soft_loss_val, optim_str + + # ── Run (wraps base to ensure hook cleanup) ────────────────────── + + def run(self, prompt, target, num_steps, max_flops=None, max_time=None, **kwargs): + try: + return super().run(prompt, target, num_steps, max_flops=max_flops, max_time=max_time, **kwargs) + finally: + self._remove_hooks() + + # ── LSGM hooks ─────────────────────────────────────────────────── + + def _register_lsgm_hooks(self) -> list: + handles = [] + gamma = self.lsgm_gamma + for name, module in self.model.named_modules(): + if any(p in name for p in _NORM_PATTERNS): + + def hook(m, grad_input, grad_output, _gamma=gamma): + grad_input[0].data *= _gamma + + handles.append(module.register_full_backward_hook(hook)) + return handles + + def _remove_hooks(self) -> None: + for h in self._lsgm_handles: + h.remove() + self._lsgm_handles.clear() + + # ── Adaptive sparsification (Algorithm 1 from ADC paper) ───────── + + @torch.no_grad() + def _make_sparse_batched(self, z: Tensor, sparsities: Tensor) -> Tensor: + """Keep top-S logits per position per restart, zero out the rest. + + z: [K, L, V] soft distributions + sparsities: [K] per-restart sparsity targets (continuous-valued) + """ + K, L, V = z.shape + result = z.clone() + + for k in range(K): + s_float = sparsities[k].item() + S_floor = int(s_float) + S_frac = s_float - S_floor + + if S_floor >= V: + result[k] = result[k].relu() + 1e-6 + result[k] /= result[k].sum(dim=-1, keepdim=True) + continue + + # Positions getting floor+1 tokens (reference clamps to min=5) + n_higher = max(int(S_frac * L), min(5, L)) + perm = torch.randperm(L, device=z.device) + + for j in range(L): + pos = perm[j].item() + s = (S_floor + 1) if j < n_higher else S_floor + s = max(s, 1) + + if s >= V: + result[k, pos] = result[k, pos].relu() + 1e-6 + else: + _, topk_idx = result[k, pos].topk(s) + new_vals = torch.zeros_like(result[k, pos]) + new_vals[topk_idx] = result[k, pos, topk_idx].relu() + 1e-6 + result[k, pos] = new_vals + + result[k, pos] /= result[k, pos].sum() + + return result