# steering_strict.py
"""Adaptive safety-vector steering for causal language models.

The module fits per-layer "safe minus unsafe" directions from two prompt sets
and then, while enabled, nudges hidden states along those directions via
forward hooks and re-weights the LM-head logits for a seed list of unsafe and
refusal tokens.
"""
import contextlib
import logging
import math
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

_SYSTEM_PROMPT = "You are a helpful, safe, and honest assistant."

# Invisible characters removed before heuristic matching: zero-width space /
# non-joiner / joiner, word joiner, BOM, soft hyphen, and LRM / RLM marks.
_INVISIBLE_CHARS_RE = re.compile(r"[\u200b\u200c\u200d\u2060\ufeff\u00ad\u200e\u200f]")
_WHITESPACE_RE = re.compile(r"\s+")

# Compiled once, case-insensitively: the prompt is lower-cased before matching,
# so a case-sensitive pattern such as ``\bDAN\b`` could never fire.
_JAILBREAK_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"ignore (all )?(previous|prior) (instructions|rules)",
        # Superset of the historical pattern; additionally matches "do not" and
        # a typographic apostrophe in "don’t".
        r"do(?:n['’]?| no| )?t follow (the )?(rules|policy|guidelines)",
        r"\bDAN\b|\bdo anything now\b",
        r"jail\s*break|j@ilbreak|jail\-?br[e3]ak",
        r"bypass (the )?(safety|filter|guard|policy)",
        r"as a (fiction|roleplay|hypothetical)",
        r"without (any )?(restrictions?|limitations?)",
        r"output (the )?steps (for|to) (illegal|dangerous|harmful)",
    )
)


def _apply_chat_template(tokenizer, prompt: str) -> str:
    """Wrap *prompt* in a system + user chat turn and return the prompt text.

    Uses ``tokenizer.apply_chat_template`` when the tokenizer has that
    attribute; otherwise falls back to a hard-coded header/eot-token layout.
    """
    if hasattr(tokenizer, "apply_chat_template"):
        msgs = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        return tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{_SYSTEM_PROMPT}\n<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )


@torch.inference_mode()
def _last_token_layer_hiddens(model, tokenizer, prompt: str, max_length: int = 2048) -> List[torch.Tensor]:
    """Return the last-position hidden state of batch item 0 for every layer.

    The first entry of ``out.hidden_states`` (the embedding output) is
    skipped, so index ``i`` of the result is the output of layer ``i``.
    """
    text = _apply_chat_template(tokenizer, prompt)
    # NOTE(review): the templated text may already start with a BOS token and
    # the tokenizer is called with its default special-token handling, so BOS
    # could be duplicated -- confirm against the tokenizer in use.
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    device = next(model.parameters()).device
    enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
    use_autocast = torch.cuda.is_available() and device.type == "cuda"
    # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda") is the same
    # context manager, and nullcontext avoids constructing it when unused.
    autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else contextlib.nullcontext()
    with autocast_ctx:
        out = model(**enc, output_hidden_states=True, use_cache=False)
    hs = out.hidden_states  # [emb, layer1, layer2, ...]
    return [h[0, -1, :].detach() for h in hs[1:]]


def _normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return *v* scaled to (approximately) unit L2 norm; a zero vector is returned as is."""
    n = v.norm()
    return v if n == 0 else v / (n + eps)


def _cos(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of two 1-D tensors as a Python float."""
    return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


@torch.inference_mode()
def _hidden_to_logits(model, h_last: torch.Tensor) -> torch.Tensor:
    """Project a hidden state through ``model.model.norm`` (if present) and ``model.lm_head``."""
    first_param = next(model.parameters())
    h = h_last.to(device=first_param.device, dtype=first_param.dtype, non_blocking=True)
    inner = getattr(model, "model", None)
    final_norm = getattr(inner, "norm", None) if inner is not None else None
    if final_norm is not None:
        h = final_norm(h)
    return model.lm_head(h)


@torch.inference_mode()
def _layer_logits_for_prompt(model, tokenizer, prompt: str, max_length: int) -> List[torch.Tensor]:
    """Return, for every layer, the logits obtained from its last-token hidden state."""
    hs = _last_token_layer_hiddens(model, tokenizer, prompt, max_length)
    return [_hidden_to_logits(model, h) for h in hs]


def _softmax_mass_over_mask(logits: torch.Tensor, mask: torch.Tensor) -> float:
    """Return the softmax probability mass of *logits* weighted by *mask*."""
    p = logits.float().softmax(dim=-1)
    return float((p * mask).sum().item())


def _tokenize_variants(tokenizer, words: Iterable[str]) -> List[int]:
    """Collect the unique token ids produced by case/leading-space variants of *words*.

    Each word is tokenized as given, lower-cased, upper-cased and capitalized,
    both with and without a leading space. Ids are returned in first-seen
    order (variants are iterated in a fixed order so the result is
    deterministic across runs).
    """
    seen: Dict[int, None] = {}
    for w in words:
        for variant in dict.fromkeys((w, w.lower(), w.upper(), w.capitalize())):
            for text in (variant, " " + variant):
                for t in tokenizer(text, add_special_tokens=False)["input_ids"]:
                    seen.setdefault(t, None)
    return list(seen)


def _build_vocab_mask(vocab_size: int, token_ids: List[int], device: torch.device) -> torch.Tensor:
    """Return a float32 vector of length *vocab_size* with 1.0 at each in-range id."""
    m = torch.zeros(vocab_size, device=device, dtype=torch.float32)
    valid = [i for i in token_ids if 0 <= i < vocab_size]
    if valid:
        m.index_fill_(0, torch.tensor(valid, device=device, dtype=torch.long), 1.0)
    return m


def _canonicalize_prompt(s: str) -> str:
    """Strip invisible characters, collapse whitespace runs, trim and lower-case *s*."""
    s = _INVISIBLE_CHARS_RE.sub("", s)
    return _WHITESPACE_RE.sub(" ", s).strip().lower()


def _simple_jailbreak_heuristics(s: str) -> float:
    """Return a score in [0, 1]: matched jailbreak patterns / 3, clamped."""
    s = _canonicalize_prompt(s)
    hits = sum(1 for pattern in _JAILBREAK_PATTERNS if pattern.search(s))
    return max(0.0, min(1.0, hits / 3.0))


@dataclass
class ASVProfile:
    """Result of :meth:`AdaptiveSafetyVectorSteerer.fit`; dicts are keyed by layer index."""

    picked_layers: List[int]
    vectors_by_layer: Dict[int, torch.Tensor]
    weights_by_layer: Dict[int, float]
    unsafe_ref_logits_by_layer: Dict[int, torch.Tensor]
    safe_ref_logits_by_layer: Dict[int, torch.Tensor]
    top_tokens_by_layer: Dict[int, List[int]]


class _SteeringContext:
    """Context manager that enables steering on entry and disables it on exit."""

    def __init__(self, outer, prompt_for_alpha: Optional[str], alpha_override: Optional[float]):
        self.outer = outer
        self._prompt_for_alpha = prompt_for_alpha
        self._alpha_override = alpha_override

    def __enter__(self):
        self.outer.enable(self._prompt_for_alpha, self._alpha_override)
        return self.outer

    def __exit__(self, exc_type, exc, tb):
        self.outer.disable()


class AdaptiveSafetyVectorSteerer:
    """Fit per-layer safety directions and apply them through forward hooks.

    Typical use: construct, call :meth:`fit` with safe and unsafe prompt sets,
    then run generation inside :meth:`steering_context`. The model is expected
    to expose ``model.model.layers`` and ``model.lm_head``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        *,
        max_length: int = 2048,
        layer_top_pct: float = 0.30,
        top_k_tokens: int = 32,
        step: float = 0.12,
        beta: float = 6.0,
        alpha_center: float = 0.5,
        preserve_norm: bool = True,
        pairwise_sample: Optional[int] = None,
    ):
        """Store configuration and initialise runtime state.

        Args:
            model: Causal LM; device and dtype are read from its first parameter.
            tokenizer: Tokenizer matching *model*.
            max_length: Truncation length used when encoding prompts.
            layer_top_pct: Fraction of layers ranked as candidates, clamped to
                [0.05, 0.95].
            top_k_tokens: Number of top unsafe tokens recorded (at least 1).
            step: Base magnitude of the hidden-state shift.
            beta: Slope of the risk -> alpha sigmoid.
            alpha_center: Risk value at which alpha equals 0.5.
            preserve_norm: Rescale steered hidden states to their original norm.
            pairwise_sample: Stored; not used by the code in this module.
        """
        self.m = model
        self.tok = tokenizer
        self.max_length = max_length
        self.layer_top_pct = float(max(0.05, min(0.95, layer_top_pct)))
        self.top_k_tokens = int(max(1, top_k_tokens))
        self.step = float(step)
        self.beta = float(beta)
        self.alpha_center = float(alpha_center)
        self.preserve_norm = bool(preserve_norm)
        self.pairwise_sample = pairwise_sample

        p = next(self.m.parameters())
        self.device = p.device
        self.dtype = p.dtype

        self.profile: Optional[ASVProfile] = None
        self._handles: List[torch.utils.hooks.RemovableHandle] = []

        # dynamic runtime
        self._last_risk: float = 0.0
        self._risk_mode_active: bool = False
        self._gen_token_counter: int = 0
        self._K_positions: int = 12

        # hysteresis (strict)
        self._tau_high: float = 0.35
        self._tau_low: float = 0.20
        self._alpha_max: float = 1.0
        self._alpha_min: float = 0.0
        self._cooldown_tokens: int = 24
        self._cooldown_counter: int = 0

        self._unsafe_token_ids: List[int] = []
        self._refusal_token_ids: List[int] = []
        self._unsafe_mask: Optional[torch.Tensor] = None
        self._refusal_mask: Optional[torch.Tensor] = None
        self._lambda_unsafe_max: float = 8.0
        self._mu_refusal_max: float = 6.0
        self._logit_topN: int = 128
        self._lm_head_handle: Optional[torch.utils.hooks.RemovableHandle] = None

        self._safety_range: Optional[Tuple[int, int]] = None
        self._danger_range: Optional[Tuple[int, int]] = None
        self._target_range: Optional[List[int]] = None

    @torch.inference_mode()
    def _mean_layer_stats(self, prompts: Sequence[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return per-layer mean last-token hidden states and mean logits.

        Each prompt is run through the model once; both means are accumulated
        in float32 on ``self.device`` to avoid low-precision accumulation.
        """
        device = self.device
        sum_h: List[torch.Tensor] = []
        sum_logits: List[torch.Tensor] = []
        for text in prompts:
            hiddens = _last_token_layer_hiddens(self.m, self.tok, text, self.max_length)
            logits = [_hidden_to_logits(self.m, h) for h in hiddens]
            if not sum_h:
                sum_h = [torch.zeros(h.shape, device=device, dtype=torch.float32) for h in hiddens]
                sum_logits = [torch.zeros(x.shape, device=device, dtype=torch.float32) for x in logits]
            for acc, h in zip(sum_h, hiddens):
                acc.add_(h.to(device=device, dtype=torch.float32))
            for acc, x in zip(sum_logits, logits):
                acc.add_(x.to(device=device, dtype=torch.float32))
        for acc in sum_h + sum_logits:
            acc.div_(len(prompts))
        return sum_h, sum_logits

    @torch.inference_mode()
    def fit(self, safe_prompts: Sequence[str], unsafe_prompts: Sequence[str]) -> ASVProfile:
        """Select steering layers and compute their steering vectors.

        Layers are ranked by (a) cosine dissimilarity between mean safe and
        mean unsafe per-layer logits and (b) cosine similarity between the mean
        unsafe token distribution and the unsafe-token mask. The intersection
        of the two top-k rankings is used, falling back to the best
        dissimilarity layers when the intersection is empty.

        Args:
            safe_prompts: Non-empty sequence of benign prompts.
            unsafe_prompts: Non-empty sequence of harmful prompts.

        Returns:
            The fitted profile, also stored on ``self.profile``.

        Raises:
            AssertionError: If either prompt set is empty.
        """
        assert len(safe_prompts) > 0 and len(unsafe_prompts) > 0, "Need non-empty safe/unsafe sets."
        self._maybe_build_vocab_masks()
        device = self.device

        mean_safe_h, mean_safe_logits = self._mean_layer_stats(safe_prompts)
        mean_unsafe_h, mean_unsafe_logits = self._mean_layer_stats(unsafe_prompts)
        n_layers = len(mean_safe_logits)
        vocab_dim = mean_safe_logits[0].shape[-1]

        safety_dissim = [
            1.0 - _cos(_normalize(safe), _normalize(unsafe))
            for safe, unsafe in zip(mean_safe_logits, mean_unsafe_logits)
        ]

        if self._unsafe_mask is not None:
            danger_prior = self._unsafe_mask.to(device=device, dtype=torch.float32)
        else:
            danger_prior = torch.zeros(vocab_dim, device=device, dtype=torch.float32)
        has_prior = bool(danger_prior.sum() > 0)
        if has_prior:
            danger_prior = danger_prior / (danger_prior.norm() + 1e-8)
        danger_sim = [
            _cos(_normalize(unsafe.float().softmax(dim=-1)), danger_prior) if has_prior else 0.0
            for unsafe in mean_unsafe_logits
        ]

        pct = max(0.05, min(0.95, self.layer_top_pct))
        k = max(1, int(round(pct * n_layers)))
        safety_top = sorted(range(n_layers), key=lambda i: safety_dissim[i], reverse=True)[:k]
        danger_top = sorted(range(n_layers), key=lambda i: danger_sim[i], reverse=True)[:k]
        picked_layers = sorted(set(safety_top).intersection(danger_top))
        if not picked_layers:
            picked_layers = sorted(safety_top[:max(1, n_layers // 20)])

        weights_by_layer: Dict[int, float] = {
            l: max(0.0, safety_dissim[l]) * max(0.0, danger_sim[l]) for l in picked_layers
        }
        total = sum(weights_by_layer.values())
        if total > 0:
            weights_by_layer = {l: w / total for l, w in weights_by_layer.items()}
        else:
            # No layer carries positive weight: fall back to a uniform split.
            weights_by_layer = {l: 1.0 / len(picked_layers) for l in picked_layers}

        vectors_by_layer: Dict[int, torch.Tensor] = {
            l: _normalize(mean_safe_h[l] - mean_unsafe_h[l]).to(dtype=self.dtype)
            for l in picked_layers
        }
        unsafe_ref_logits_by_layer = {
            l: mean_unsafe_logits[l].to("cpu", dtype=torch.float32) for l in picked_layers
        }
        safe_ref_logits_by_layer = {
            l: mean_safe_logits[l].to("cpu", dtype=torch.float32) for l in picked_layers
        }

        # Summing the per-layer means over the picked layers ranks tokens the
        # same way as summing every unsafe prompt's logits (the two differ only
        # by the positive factor len(unsafe_prompts)), without another pass.
        agg_unsafe = torch.stack([mean_unsafe_logits[l] for l in picked_layers]).sum(dim=0)
        topk = min(max(5, self.top_k_tokens), vocab_dim)
        top_tokens = torch.topk(agg_unsafe, k=topk, largest=True).indices.tolist()
        top_tokens_by_layer = {l: top_tokens for l in picked_layers}

        self._target_range = picked_layers
        self.profile = ASVProfile(
            picked_layers=picked_layers,
            vectors_by_layer=vectors_by_layer,
            weights_by_layer=weights_by_layer,
            unsafe_ref_logits_by_layer=unsafe_ref_logits_by_layer,
            safe_ref_logits_by_layer=safe_ref_logits_by_layer,
            top_tokens_by_layer=top_tokens_by_layer,
        )
        return self.profile

    @torch.inference_mode()
    def _update_risk_from_logits(self, logits_2d: torch.Tensor):
        """Recompute ``_last_risk`` from logits and advance the hysteresis state.

        Only row 0 is inspected when a 2-D tensor is given.
        """
        lv = logits_2d if logits_2d.dim() == 1 else logits_2d[0]

        unsafe_mass = 0.0
        if self._unsafe_mask is not None:
            unsafe_mass = _softmax_mass_over_mask(lv, self._unsafe_mask.to(lv.device))

        comp_unsafe_sim = 0.0
        if self.profile is not None and self.profile.picked_layers:
            l0 = min(self.profile.picked_layers)
            uref = _normalize(
                self.profile.unsafe_ref_logits_by_layer[l0].to(lv.device, dtype=torch.float32)
            )
            cur = _normalize(lv.to(dtype=torch.float32))
            # NOTE(review): this term grows as the logits become *less* similar
            # to the unsafe reference (1 - cos). Kept as is; confirm whether
            # the complement is intended.
            comp_unsafe_sim = max(0.0, 1.0 - _cos(cur, uref))

        w1, w2 = 0.7, 0.3
        risk = max(0.0, min(1.0, w1 * comp_unsafe_sim + w2 * unsafe_mass))

        if self._risk_mode_active:
            # Leave risk mode only once risk is low AND the cooldown has expired.
            if risk <= self._tau_low and self._cooldown_counter <= 0:
                self._risk_mode_active = False
            else:
                self._cooldown_counter = max(self._cooldown_counter - 1, 0)
        elif risk >= self._tau_high:
            self._risk_mode_active = True
            self._cooldown_counter = self._cooldown_tokens
        self._last_risk = float(risk)

    def _alpha_from_current_risk(self) -> float:
        """Map ``_last_risk`` through a sigmoid and clamp to ``[_alpha_min, _alpha_max]``."""
        z = self.beta * (float(self._last_risk) - self.alpha_center)
        # Numerically stable sigmoid: math.exp never receives a large positive
        # argument, so extreme risk values cannot raise OverflowError.
        if z >= 0:
            a = 1.0 / (1.0 + math.exp(-z))
        else:
            e = math.exp(z)
            a = e / (1.0 + e)
        return float(max(self._alpha_min, min(self._alpha_max, a)))

    @torch.inference_mode()
    def enable(self, prompt_for_alpha: Optional[str] = None, alpha_override: Optional[float] = None):
        """Register the layer hooks and the LM-head hook.

        Args:
            prompt_for_alpha: If given (and no override), the jailbreak
                heuristic score of this prompt raises ``_last_risk``.
            alpha_override: If given, written directly to ``_last_risk``.
                NOTE(review): despite the name this seeds the risk value that
                is then passed through the sigmoid, not alpha itself.

        Raises:
            AssertionError: If :meth:`fit` has not been called.
        """
        assert self.profile is not None, "Call fit(...) first."
        self.disable()
        if alpha_override is not None:
            self._last_risk = float(alpha_override)
        elif prompt_for_alpha:
            heur = _simple_jailbreak_heuristics(prompt_for_alpha)
            self._last_risk = max(self._last_risk, float(heur))

        for l in self.profile.picked_layers:
            if self.profile.weights_by_layer.get(l, 0.0) <= 0.0:
                continue
            block = self.m.model.layers[l]
            self._handles.append(block.register_forward_hook(self._make_hook_for_layer(l)))

        if getattr(self.m, "lm_head", None) is not None:
            self._lm_head_handle = self.m.lm_head.register_forward_hook(self._make_lm_head_hook())
        self._gen_token_counter = 0

    def disable(self):
        """Remove every registered hook; safe to call when nothing is registered."""
        for h in self._handles:
            try:
                h.remove()
            except Exception:  # best effort: keep removing the remaining hooks
                logger.debug("Failed to remove layer hook", exc_info=True)
        self._handles = []
        if self._lm_head_handle is not None:
            try:
                self._lm_head_handle.remove()
            except Exception:  # best effort
                logger.debug("Failed to remove lm_head hook", exc_info=True)
            self._lm_head_handle = None

    def _make_hook_for_layer(self, l: int):
        """Build the forward hook that shifts layer *l*'s output along its vector.

        The hook edits the last ``min(_K_positions, T)`` sequence positions in
        place; the shift decays linearly from 1.0 at the final position to 0.55
        at the earliest steered one.
        """
        v_l = self.profile.vectors_by_layer[l].to(device=self.device, dtype=self.dtype, non_blocking=True)
        w_l = float(self.profile.weights_by_layer.get(l, 0.0))
        base_step = float(self.step)
        preserve = self.preserve_norm
        K = int(max(1, self._K_positions))

        @torch.inference_mode()
        def hook(module, inputs, output):
            if w_l <= 0.0:
                return output
            h = output[0] if isinstance(output, tuple) else output
            if not isinstance(h, torch.Tensor) or h.dim() != 3:
                return output
            T = h.shape[1]
            if T == 0:
                return output

            alpha_p = self._alpha_from_current_risk()
            risk_flag = 1.0 if self._risk_mode_active else 0.0
            gain = (alpha_p * w_l * base_step) * (0.50 + 0.50 * risk_flag)

            k = min(K, T)
            tail = h[:, T - k:, :]  # view: writes go into the layer output
            # linspace runs 1.0 -> 0.55 from the last position backwards, so
            # flip it to line up with the forward-ordered tail slice.
            decays = torch.linspace(1.0, 0.55, steps=k, device=h.device, dtype=h.dtype).flip(0).view(1, k, 1)
            direction = v_l.to(device=h.device, dtype=h.dtype)
            steered = tail + (gain * decays) * direction
            if preserve:
                base_norm = tail.norm(dim=-1, keepdim=True)
                steered = steered * (base_norm / (steered.norm(dim=-1, keepdim=True) + 1e-8))
            tail.copy_(steered)
            return output

        return hook

    def _make_lm_head_hook(self):
        """Build the LM-head hook that updates risk and filters last-position logits."""
        unsafe_mask = self._unsafe_mask
        refusal_mask = self._refusal_mask
        topN = int(self._logit_topN)

        @torch.inference_mode()
        def hook(module, inputs, output):
            if not isinstance(output, torch.Tensor):
                return output
            logits = output
            if logits.dim() == 3:
                last = logits[:, -1, :]
            elif logits.dim() == 2:
                last = logits
            else:
                return output

            self._update_risk_from_logits(last)
            self._apply_logit_filters_inplace(last, unsafe_mask, refusal_mask, topN)

            if self._risk_mode_active:
                self._gen_token_counter += 1
                if self._gen_token_counter >= self._cooldown_tokens:
                    self._cooldown_counter = max(self._cooldown_counter - 1, 0)
            return output

        return hook

    @torch.inference_mode()
    def _apply_logit_filters_inplace(
        self,
        logits_2d: torch.Tensor,
        unsafe_mask: Optional[torch.Tensor],
        refusal_mask: Optional[torch.Tensor],
        topN: int,
    ):
        """Penalise unsafe tokens and boost refusal tokens in place.

        Only tokens among each row's ``topN`` highest logits are touched. The
        strength ramps linearly with ``_last_risk`` between ``_tau_low`` and
        ``_tau_high`` and is at least 0.6 of the maximum while risk mode is
        active. Both adjustments use the probabilities computed before any
        modification.
        """
        if logits_2d.numel() == 0:
            return
        r = float(self._last_risk)
        ramp = max(0.0, min(1.0, (r - self._tau_low) / max(1e-6, (self._tau_high - self._tau_low))))
        if self._risk_mode_active:
            ramp = max(ramp, 0.6)
        lam = self._lambda_unsafe_max * ramp
        mu = self._mu_refusal_max * ramp
        if lam <= 0 and mu <= 0:
            return

        if logits_2d.dim() == 1:
            logits_2d = logits_2d.unsqueeze(0)  # still a view of the caller's tensor

        probs = logits_2d.softmax(-1)
        k = min(topN, logits_2d.shape[-1])
        top_idx = torch.topk(logits_2d, k=k, dim=-1).indices
        in_top = torch.zeros_like(logits_2d).scatter_(-1, top_idx, 1.0)

        if unsafe_mask is not None and lam > 0.0:
            logits_2d.sub_(lam * probs * unsafe_mask.to(logits_2d.device) * in_top)
        if refusal_mask is not None and mu > 0.0:
            logits_2d.add_(mu * (1.0 - probs) * refusal_mask.to(logits_2d.device) * in_top)

    def steering_context(self, prompt_for_alpha: Optional[str] = None, alpha_override: Optional[float] = None):
        """Return a context manager that calls :meth:`enable` on entry and :meth:`disable` on exit."""
        return _SteeringContext(self, prompt_for_alpha, alpha_override)

    def _maybe_build_vocab_masks(self):
        """Build the unsafe/refusal vocabulary masks once from the seed word lists."""
        if self._unsafe_mask is not None and self._refusal_mask is not None:
            return
        unsafe_seed = [
            "bomb", "explosive", "grenade", "molotov", "napalm", "assassinate", "kill", "poison", "harm", "gun", "rifle",
            "bypass", "exploit", "xss", "sql injection", "xss payload", "buffer overflow", "ransomware", "malware", "keylogger", "backdoor",
            "credit card", "ssn", "social security", "password", "private key",
            "make drugs", "counterfeit", "child", "terror", "terrorist", "fraud",
        ]
        refusal_seed = [
            "cannot", "can't", "won't", "sorry", "I cannot", "I can't", "I will not",
            "I will not help", "policy", "safety", "cannot assist", "not able to help",
            "I can provide general information", "Here are safer alternatives",
        ]
        try:
            self._unsafe_token_ids = _tokenize_variants(self.tok, unsafe_seed)
            self._refusal_token_ids = _tokenize_variants(self.tok, refusal_seed)
        except Exception:
            # Best effort: continue with empty masks, but make the degradation
            # visible because the logit filters become no-ops.
            logger.warning("Tokenizing seed words failed; vocabulary masks will be empty", exc_info=True)
            self._unsafe_token_ids = []
            self._refusal_token_ids = []

        vocab_size = getattr(self.m.lm_head, "out_features", None)
        if vocab_size is None and hasattr(self.m.lm_head, "weight"):
            vocab_size = self.m.lm_head.weight.shape[0]
        if vocab_size is None:
            vocab_size = int(getattr(self.tok, "vocab_size", 0) or 0)

        self._unsafe_mask = _build_vocab_mask(vocab_size, self._unsafe_token_ids, self.device)
        self._refusal_mask = _build_vocab_mask(vocab_size, self._refusal_token_ids, self.device)