Files
OBLITERATUS/scripts/abliteration_comparison.py
2026-03-04 12:38:18 -08:00

804 lines
31 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Abliteration Technique Comparison Study.
A rigorous, controlled comparison of refusal-direction removal techniques.
Uses a synthetic "planted refusal direction" methodology: we inject a known
direction into a model's activations so we can measure whether each technique
correctly identifies and removes it.
Additionally compiles literature results for a full comparison table.
Techniques compared:
1. Arditi et al. (2024) — difference-of-means, last token, raw prompts
2. Arditi + chat template — same but with chat-formatted prompts
3. FailSpy/abliterator — Arditi with middle-60% layer heuristic
4. Gabliteration — SVD multi-direction (4 dirs), regularization 0.0
5. grimjim — Gabliteration + norm preservation
6. OBLITERATUS basic — our current basic config
7. OBLITERATUS advanced — 4 directions, norm-preserve, reg=0.3
8. Heretic (p-e-w) — TPE Bayesian optimization (literature)
Metrics:
- Direction recovery: cosine similarity to planted ground-truth direction
- Residual after projection: how much of the refusal direction remains
- Capability preservation: Frobenius distance of modified vs original weights
- Layer selection accuracy: did it pick the right layers?
- Perplexity proxy: exp(mean language-modeling loss) on random token sequences, measured after projection
"""
from __future__ import annotations
import gc
import json
import math
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ══════════════════════════════════════════════════════════════════════════
# Synthetic model with planted refusal direction
# ══════════════════════════════════════════════════════════════════════════
def create_synthetic_model(
    hidden_dim: int = 128,
    n_layers: int = 12,
    n_heads: int = 4,
    vocab_size: int = 1000,
    seq_len: int = 64,
):
    """Build a small GPT-2 language model for the controlled experiments.

    All three dropout probabilities are set to 0.0 and the model is switched
    to eval mode before it is returned.

    Args:
        hidden_dim: Embedding width (``n_embd``); the MLP inner width is
            four times this value.
        n_layers: Number of transformer blocks.
        n_heads: Number of attention heads per block.
        vocab_size: Vocabulary size.
        seq_len: Maximum number of positions (``n_positions``).

    Returns:
        ``(model, config)`` — the ``GPT2LMHeadModel`` and the ``GPT2Config``
        it was built from.
    """
    # Imported lazily so the module can be loaded without transformers.
    from transformers import GPT2Config, GPT2LMHeadModel

    dropout_off = dict.fromkeys(("resid_pdrop", "attn_pdrop", "embd_pdrop"), 0.0)
    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=seq_len,
        n_embd=hidden_dim,
        n_layer=n_layers,
        n_head=n_heads,
        n_inner=hidden_dim * 4,
        **dropout_off,
    )
    lm = GPT2LMHeadModel(config)
    lm.eval()
    return lm, config
def plant_refusal_direction(
    model: nn.Module,
    target_layers: list[int],
    hidden_dim: int,
    n_directions: int = 1,
    signal_strength: float = 5.0,
    seed: int = 42,
) -> tuple[dict[int, torch.Tensor], dict[int, torch.Tensor]]:
    """Inject known "refusal" directions into the attention output projections.

    For every layer in ``target_layers`` a set of ``n_directions`` random,
    Gram-Schmidt-orthonormalised vectors is drawn and a rank-1 term
    ``s_k * d_k d_k^T`` is added in place to ``attn.c_proj.weight`` for each
    of them, with ``s_k = signal_strength * 0.7 ** k`` so that secondary
    directions are progressively weaker.

    Note: this reseeds the global torch RNG with ``seed``.

    Args:
        model: Model exposing ``model.transformer.h[i].attn.c_proj.weight``.
        target_layers: Indices of the layers to perturb.
        hidden_dim: Length of each planted direction.
        n_directions: Number of orthonormal directions per layer.
        signal_strength: Scale of the primary rank-1 perturbation.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        ``(planted_directions, planted_subspaces)``: per layer, the primary
        direction (shape ``(hidden_dim,)``) and the full set of planted
        directions (shape ``(n_directions, hidden_dim)``).
    """
    torch.manual_seed(seed)

    primary_by_layer: dict[int, torch.Tensor] = {}
    basis_by_layer: dict[int, torch.Tensor] = {}

    for layer_idx in target_layers:
        # Random rows, then classic Gram-Schmidt to make them orthonormal.
        basis = torch.randn(n_directions, hidden_dim)
        for row in range(n_directions):
            for earlier in range(row):
                basis[row] -= (basis[row] @ basis[earlier]) * basis[earlier]
            basis[row] = basis[row] / basis[row].norm()

        primary_by_layer[layer_idx] = basis[0].clone()
        basis_by_layer[layer_idx] = basis.clone()

        attention = model.transformer.h[layer_idx].attn
        with torch.no_grad():
            for rank, unit in enumerate(basis):
                # Secondary directions get a geometrically smaller gain.
                gain = signal_strength * (0.7 ** rank)
                weight = attention.c_proj.weight.data
                # Rank-1 outer product gain * d d^T, added in place.
                weight.add_(gain * unit.unsqueeze(1) @ unit.unsqueeze(0))

    return primary_by_layer, basis_by_layer
def measure_residual_direction(
    model: nn.Module,
    layer_idx: int,
    direction: torch.Tensor,
) -> float:
    """Return ``||W @ d||`` for a layer's attention output projection.

    ``W`` is ``model.transformer.h[layer_idx].attn.c_proj.weight`` and ``d``
    is ``direction`` moved to ``W``'s device and dtype. The direction is not
    re-normalised here, so the result scales with ``||d||``.

    Args:
        model: Model exposing ``model.transformer.h[i].attn.c_proj.weight``.
        layer_idx: Index of the layer to inspect.
        direction: Vector whose length matches ``W``'s last dimension.

    Returns:
        The Euclidean norm of ``W @ d`` as a Python float.
    """
    weight = model.transformer.h[layer_idx].attn.c_proj.weight.data
    aligned = direction.to(weight.device, weight.dtype)
    return (weight @ aligned).norm().item()
def collect_synthetic_activations(
    model: nn.Module,
    n_prompts: int,
    seq_len: int,
    vocab_size: int,
    n_layers: int,
    add_refusal_signal: bool = False,
    signal_direction: dict[int, torch.Tensor] | None = None,
    signal_strength: float = 2.0,
    seed: int = 0,
) -> dict[int, list[torch.Tensor]]:
    """Record last-token block outputs on random token sequences.

    A forward hook on each of the first ``n_layers`` blocks stores the
    last-position output (detached, on CPU, as float32) for every prompt.
    With ``add_refusal_signal=True``, ``signal_strength * d`` is added to the
    *recorded* activation of every layer present in ``signal_direction``;
    the hook returns nothing, so the model's own forward pass is unchanged.

    Note: this reseeds the global torch RNG with ``seed``.

    Args:
        model: Callable model exposing ``model.transformer.h``.
        n_prompts: Number of random sequences to run.
        seq_len: Length of each sequence.
        vocab_size: Exclusive upper bound for sampled token ids.
        n_layers: Number of leading blocks to hook.
        add_refusal_signal: Whether to add the synthetic signal.
        signal_direction: Per-layer direction to add (CPU tensors expected,
            since they are added to CPU activations).
        signal_strength: Multiplier for the added direction.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        ``{layer_idx: [activation, ...]}`` with one ``(batch, hidden)``
        tensor per prompt.
    """
    torch.manual_seed(seed)
    activations: dict[int, list[torch.Tensor]] = {layer: [] for layer in range(n_layers)}

    def recorder_for(layer_idx: int):
        def _on_forward(module, inputs, output):
            # Some blocks return a tuple whose first item is the hidden state.
            hidden = output[0] if isinstance(output, tuple) else output
            last_token = hidden[:, -1, :].detach().cpu().float()
            if add_refusal_signal and signal_direction and layer_idx in signal_direction:
                boost = signal_direction[layer_idx].unsqueeze(0)
                last_token = last_token + signal_strength * boost
            activations[layer_idx].append(last_token)

        return _on_forward

    blocks = list(model.transformer.h)
    handles = [
        blocks[layer_idx].register_forward_hook(recorder_for(layer_idx))
        for layer_idx in range(n_layers)
    ]

    try:
        for _ in range(n_prompts):
            token_ids = torch.randint(0, vocab_size, (1, seq_len))
            with torch.no_grad():
                model(token_ids)
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for handle in handles:
            handle.remove()

    return activations
# ══════════════════════════════════════════════════════════════════════════
# Reference baseline implementations
# ══════════════════════════════════════════════════════════════════════════
def extract_directions(
    harmful_acts: dict[int, list[torch.Tensor]],
    harmless_acts: dict[int, list[torch.Tensor]],
    n_layers: int,
    n_directions: int = 1,
) -> tuple[dict[int, torch.Tensor], dict[int, torch.Tensor], dict[int, float]]:
    """Extract refusal directions from harmful-vs-harmless activation contrasts.

    With ``n_directions == 1`` the direction is the normalised difference of
    the two per-layer activation means, and the layer's score is the norm of
    that difference. Otherwise the harmful and harmless activations are paired
    by index, and the top-``k`` right singular vectors of the difference
    matrix form the subspace; the score is the sum of the ``k`` largest
    squared singular values.

    Layers for which no direction can be produced (zero mean difference,
    no usable rank, or a failed SVD) are omitted from all three dicts.

    Args:
        harmful_acts: Per-layer list of ``(1, hidden)`` activations.
        harmless_acts: Same structure for the harmless side.
        n_layers: Layers ``0 .. n_layers - 1`` are processed.
        n_directions: Number of directions to extract per layer.

    Returns:
        ``(directions, subspaces, norms)`` keyed by layer index: the primary
        direction, the ``(k, hidden)`` subspace, and the layer's score.
    """
    import warnings

    directions: dict[int, torch.Tensor] = {}
    subspaces: dict[int, torch.Tensor] = {}
    norms: dict[int, float] = {}

    for idx in range(n_layers):
        h_stack = torch.stack(harmful_acts[idx]).squeeze(1)
        s_stack = torch.stack(harmless_acts[idx]).squeeze(1)

        if n_directions == 1:
            diff = h_stack.mean(dim=0) - s_stack.mean(dim=0)
            norm = diff.norm().item()
            if norm > 0:
                directions[idx] = diff / diff.norm()
                subspaces[idx] = directions[idx].unsqueeze(0)
                norms[idx] = norm
            continue

        # Pair samples by index; any surplus on either side is ignored.
        min_n = min(h_stack.shape[0], s_stack.shape[0])
        diff_matrix = torch.nan_to_num(h_stack[:min_n] - s_stack[:min_n])
        k = min(n_directions, diff_matrix.shape[0], diff_matrix.shape[1])
        if k < 1:
            # Nothing to extract (e.g. n_directions <= 0); skip explicitly
            # instead of relying on an exception from indexing an empty basis.
            continue

        try:
            _, S, Vh = torch.linalg.svd(diff_matrix, full_matrices=False)
        except RuntimeError as err:
            # torch.linalg.LinAlgError (non-convergence) derives from
            # RuntimeError. Keep the best-effort skip, but make it visible.
            warnings.warn(
                f"SVD failed for layer {idx}; skipping it: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            continue

        sub = Vh[:k]
        primary = sub[0]
        pn = primary.norm()
        if pn > 1e-8:
            primary = primary / pn
        directions[idx] = primary
        subspaces[idx] = sub
        norms[idx] = (S[:k] ** 2).sum().item()

    return directions, subspaces, norms
def select_layers(
    norms: dict[int, float],
    n_layers: int,
    method: str = "top_norm",
) -> list[int]:
    """Choose which layers to abliterate from their per-layer scores.

    Layers are first ranked by score, highest first. The strategies are:

    * ``"middle_60"`` — every scored layer whose index lies in
      ``[int(0.2 * n_layers), int(0.8 * n_layers))``.
    * ``"knee"`` — the leading layers up to the point on the normalised
      score curve farthest from the straight line joining its first and
      last points, dropping any layer scoring below 5% of the maximum.
    * anything else (including ``"top_norm"``) — every layer scoring at
      least half of the maximum.

    Args:
        norms: Score per layer index.
        n_layers: Total number of layers (only used by ``"middle_60"``).
        method: Selection strategy name.

    Returns:
        Selected layer indices in descending-score order. Empty when
        ``norms`` is empty; otherwise falls back to the single best layer
        when a strategy selects nothing.
    """
    ranked = sorted(norms.items(), key=lambda item: item[1], reverse=True)
    if not ranked:
        return []

    best_layer, best_score = ranked[0]
    fallback = [best_layer]

    if method == "middle_60":
        lower = int(n_layers * 0.2)
        upper = int(n_layers * 0.8)
        chosen = [layer for layer, _ in ranked if lower <= layer < upper]
        return chosen or fallback

    if method == "knee":
        # Too few points for a knee, or a degenerate (non-positive) curve.
        if len(ranked) < 3 or best_score <= 0:
            return fallback

        curve = [score / best_score for _, score in ranked]
        count = len(curve)
        x_s, y_s = 0.0, curve[0]
        x_e, y_e = 1.0, curve[-1]
        chord = math.sqrt((x_e - x_s) ** 2 + (y_e - y_s) ** 2)

        cutoff, farthest = 1, 0.0
        if chord > 0:
            for pos in range(1, count - 1):
                x_i = pos / (count - 1)
                y_i = curve[pos]
                # Perpendicular distance from (x_i, y_i) to the chord.
                gap = abs((y_e - y_s) * x_i - (x_e - x_s) * y_i
                          + x_e * y_s - y_e * x_s) / chord
                if gap > farthest:
                    farthest = gap
                    cutoff = pos + 1

        floor = best_score * 0.05
        chosen = [layer for layer, score in ranked[:cutoff] if score >= floor]
        return chosen or fallback

    # Default strategy: keep layers within a factor of two of the best score.
    half_max = best_score * 0.5
    chosen = [layer for layer, score in ranked if score >= half_max]
    return chosen or fallback
def apply_projection(
    model: nn.Module,
    selected_layers: list[int],
    subspaces: dict[int, torch.Tensor],
    regularization: float = 0.0,
    norm_preserve: bool = False,
    multi_dir_norm_fix: bool = False,
) -> int:
    """Project refusal directions out of the 2-D weights of selected layers.

    For every direction ``d`` of a layer's subspace and every 2-D ``weight``
    found in that layer's sub-modules, with ``s = 1 - regularization``:

    * if the weight's last dimension equals ``len(d)``:
      ``W <- W - s * (W d) d^T``
    * otherwise, if its first dimension equals ``len(d)``:
      ``W <- W - s * d (d^T W)``
    * otherwise the weight is left untouched.

    Each direction is moved to the weight's device and dtype before use, so
    directions extracted on CPU in float32 can be applied to weights stored
    elsewhere or in another precision.

    Norm preservation (``norm_preserve=True``) rescales a weight back to its
    original Frobenius norm. By default this happens after every single
    direction. With ``multi_dir_norm_fix=True`` and more than one direction,
    norms are captured once before any projection and restored once after
    all directions have been removed.

    Args:
        model: Model exposing ``model.transformer.h``.
        selected_layers: Layer indices to modify; layers without an entry in
            ``subspaces`` are skipped.
        subspaces: Per-layer ``(k, hidden)`` tensor of directions.
        regularization: Fraction of each direction's component to keep.
        norm_preserve: Whether to restore weight norms after projecting.
        multi_dir_norm_fix: Use the capture-once / restore-once scheme.

    Returns:
        Number of (weight, direction) projections performed.
    """
    scale = 1.0 - regularization
    n_modified = 0

    for idx in selected_layers:
        sub = subspaces.get(idx)
        if sub is None:
            continue
        layer = model.transformer.h[idx]

        # These depend only on the layer's subspace, so decide them once.
        restore_once = multi_dir_norm_fix and norm_preserve and sub.shape[0] > 1
        restore_per_dir = norm_preserve and not restore_once

        # Capture norms before any projection (capture-once / restore-once).
        saved_norms: dict[str, float] = {}
        if restore_once:
            for name, param in layer.named_parameters():
                if name.endswith(".weight") and param.dim() == 2:
                    saved_norms[name] = param.data.norm().item()

        for direction in sub:
            for module in layer.modules():
                weight = getattr(module, "weight", None)
                if weight is None:
                    continue
                W = weight.data
                if W.dim() != 2:
                    continue

                # Match the weight's device/dtype; a mismatch would otherwise
                # make the matmul below raise.
                d = direction.to(device=W.device, dtype=W.dtype).unsqueeze(-1)
                hidden = d.shape[0]
                on_last_dim = W.shape[-1] == hidden
                if not on_last_dim and W.shape[0] != hidden:
                    continue

                # Per-direction norm restore (the old sequential behaviour).
                original_norm = W.norm().item() if restore_per_dir else 0.0

                if on_last_dim:
                    W.sub_(d.T * (scale * (W @ d)))
                else:
                    W.sub_((scale * d) * (d.T @ W))
                n_modified += 1

                if restore_per_dir and original_norm > 0:
                    new_norm = W.norm().item()
                    if new_norm > 0:
                        W.mul_(original_norm / new_norm)

        # Restore norms once, after all directions have been projected out.
        if restore_once and saved_norms:
            for name, param in layer.named_parameters():
                orig = saved_norms.get(name)
                if orig is None or orig <= 0:
                    continue
                cur = param.data.norm().item()
                if cur > 0 and abs(cur - orig) > 1e-6:
                    param.data.mul_(orig / cur)

    return n_modified
# ══════════════════════════════════════════════════════════════════════════
# Experiment runner
# ══════════════════════════════════════════════════════════════════════════
def run_experiment():
    """Run the full comparison experiment with synthetic planted directions.

    For every technique configuration a fresh synthetic model is built,
    known directions are planted in ``target_layers``, activations are
    collected, directions are extracted, layers are selected and the
    projection is applied. The global torch RNG is reseeded before each
    model is built so that every technique is evaluated on the same base
    weights rather than on whatever RNG state the previous run left behind.

    Returns:
        A list with one metrics dict per technique (keys: ``name``,
        ``source``, ``n_directions``, ``regularization``, ``norm_preserve``,
        ``direction_recovery``, ``subspace_recovery``, ``primary_removal``,
        ``multi_dir_avg_residual``, ``layers_correct``,
        ``layers_false_positive``, ``layers_missed``, ``n_layers_selected``,
        ``weight_distance``, ``perplexity``, ``time_seconds``).
    """
    # Configuration
    hidden_dim = 128
    n_layers = 12
    n_heads = 4
    vocab_size = 1000
    seq_len = 32
    n_prompts = 48  # prompts per side (harmful + harmless)
    n_planted_dirs = 4  # ground truth directions planted
    signal_strength = 5.0
    target_layers = [3, 4, 5, 6, 7, 8]  # layers with planted signal
    model_seed = 0  # identical base weights for every technique
    rule = "-" * 80  # separator between per-technique sections

    def mean(values):
        """Arithmetic mean, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def technique(name, source, n_directions, layer_selection,
                  regularization=0.0, norm_preserve=False,
                  multi_dir_norm_fix=False):
        """Bundle one technique's settings into a config dict."""
        return {
            "name": name,
            "source": source,
            "n_directions": n_directions,
            "layer_selection": layer_selection,
            "regularization": regularization,
            "norm_preserve": norm_preserve,
            "multi_dir_norm_fix": multi_dir_norm_fix,
        }

    print(f"\n{'='*80}")
    print("ABLITERATION TECHNIQUE COMPARISON — SYNTHETIC PLANTED-DIRECTION TEST")
    print(f"{'='*80}")
    print(f"Model: GPT-2 tiny ({hidden_dim}d, {n_layers}L, {n_heads}H)")
    print(f"Target layers: {target_layers}")
    print(f"Planted dirs: {n_planted_dirs} orthogonal directions per target layer")
    print(f"Signal strength: {signal_strength}")
    print(f"Prompts: {n_prompts} per side")
    print(f"{'='*80}\n")

    # Define experiments
    experiments = [
        technique("Arditi (1-dir, top-norm)", "Arditi 2024", 1, "top_norm"),
        technique("FailSpy (1-dir, mid-60%)", "FailSpy", 1, "middle_60"),
        technique("Gabliteration (4-dir, knee)", "Gabliteration", 4, "knee"),
        # Old buggy sequential norm-preserve
        technique("grimjim (4-dir, norm-pres, BUGGY)", "grimjim", 4, "knee",
                  norm_preserve=True, multi_dir_norm_fix=False),
        # Our fix: capture once, restore once
        technique("grimjim (4-dir, norm-pres, FIXED)", "Ours (fix)", 4, "knee",
                  norm_preserve=True, multi_dir_norm_fix=True),
        technique("OBLITERATUS basic (1-dir, knee)", "Ours", 1, "knee"),
        technique("OBLITERATUS adv (4-dir, reg=0.3)", "Ours", 4, "knee",
                  regularization=0.3, norm_preserve=True, multi_dir_norm_fix=True),
        technique("OBLITERATUS adv (4-dir, reg=0.1)", "Ours (tuned)", 4, "knee",
                  regularization=0.1, norm_preserve=True, multi_dir_norm_fix=True),
        technique("OBLITERATUS adv (4-dir, reg=0.0)", "Ours (tuned)", 4, "knee",
                  regularization=0.0, norm_preserve=True, multi_dir_norm_fix=True),
    ]

    results = []
    for exp in experiments:
        print(f"\n{rule}")
        print(f" {exp['name']}")
        print(f" Source: {exp['source']}")
        print(f"{rule}")
        t0 = time.time()

        # Create fresh model. Reseed first: without this the first model
        # would come from an unseeded RNG and later ones from leftover state,
        # so techniques would be compared on different base weights.
        torch.manual_seed(model_seed)
        model, _ = create_synthetic_model(hidden_dim, n_layers, n_heads, vocab_size, seq_len)

        # Plant ground-truth refusal directions
        planted_dirs, planted_subs = plant_refusal_direction(
            model, target_layers, hidden_dim,
            n_directions=n_planted_dirs,
            signal_strength=signal_strength,
            seed=42,
        )

        # Save original weights for capability comparison
        original_state = {k: v.clone() for k, v in model.state_dict().items()}

        # Measure pre-projection residuals (baseline)
        pre_residuals = {
            idx: measure_residual_direction(model, idx, planted_dirs[idx])
            for idx in target_layers
        }

        # Step 1: Collect activations
        harmful_acts = collect_synthetic_activations(
            model, n_prompts, seq_len, vocab_size, n_layers,
            add_refusal_signal=True,
            signal_direction=planted_dirs,
            signal_strength=2.0,
            seed=100,
        )
        harmless_acts = collect_synthetic_activations(
            model, n_prompts, seq_len, vocab_size, n_layers,
            add_refusal_signal=False,
            seed=200,
        )

        # Step 2: Extract directions
        ext_dirs, ext_subs, ext_norms = extract_directions(
            harmful_acts, harmless_acts, n_layers, exp["n_directions"],
        )

        # Step 3: Select layers
        selected = select_layers(ext_norms, n_layers, exp["layer_selection"])
        print(f" Selected layers: {selected}")

        # Step 4: Apply projection
        apply_projection(
            model, selected, ext_subs,
            regularization=exp["regularization"],
            norm_preserve=exp["norm_preserve"],
            multi_dir_norm_fix=exp["multi_dir_norm_fix"],
        )

        # ── Measure results ──────────────────────────────────────────────
        # Direction recovery: cosine similarity between extracted and planted
        cos_sims = []
        for idx in target_layers:
            if idx in ext_dirs and idx in planted_dirs:
                cos = F.cosine_similarity(
                    ext_dirs[idx].unsqueeze(0),
                    planted_dirs[idx].unsqueeze(0),
                ).item()
                cos_sims.append(abs(cos))  # direction or anti-direction
        avg_cos = mean(cos_sims)

        # Subspace recovery: norm of each planted direction's projection onto
        # the extracted subspace. Only the first min(k_plant, k_ext) planted
        # directions are scored.
        subspace_recovery = []
        for idx in target_layers:
            if idx in ext_subs and idx in planted_subs:
                ext_sub = ext_subs[idx]  # (k_ext, hidden)
                plant_sub = planted_subs[idx]  # (k_plant, hidden)
                for pi in range(min(plant_sub.shape[0], ext_sub.shape[0])):
                    proj = ext_sub @ plant_sub[pi]  # (k_ext,)
                    subspace_recovery.append(proj.norm().item())
        avg_subspace = mean(subspace_recovery)

        # Residual after projection
        post_residuals = {}
        for idx in target_layers:
            if idx in selected:
                post_residuals[idx] = measure_residual_direction(model, idx, planted_dirs[idx])
            else:
                post_residuals[idx] = pre_residuals[idx]  # layer wasn't modified

        removal_scores = []
        for idx in target_layers:
            pre = pre_residuals[idx]
            post = post_residuals[idx]
            if pre > 0:
                removal_scores.append(1.0 - (post / pre))
        avg_removal = mean(removal_scores)

        # Multi-direction residual: post-projection residual of ALL planted
        # directions on the layers that were actually modified. No
        # pre-projection baseline is recorded for secondary directions, so
        # this is an absolute magnitude rather than a removal fraction.
        multi_dir_residuals = []
        for idx in target_layers:
            if idx not in selected:
                continue
            for di in range(planted_subs[idx].shape[0]):
                residual = measure_residual_direction(model, idx, planted_subs[idx][di])
                multi_dir_residuals.append(residual)
        avg_multi_residual = mean(multi_dir_residuals)

        # Layer selection accuracy
        selected_set = set(selected)
        target_set = set(target_layers)
        correct_selected = len(selected_set & target_set)
        false_selected = len(selected_set - target_set)
        missed = len(target_set - selected_set)

        # Capability preservation: Frobenius distance of weights
        new_state = model.state_dict()
        total_dist = 0.0
        for key in original_state:
            diff = (new_state[key].float() - original_state[key].float())
            total_dist += diff.norm().item() ** 2
        total_dist = math.sqrt(total_dist)

        # Perplexity proxy: loss on random sequences
        losses = []
        for _ in range(10):
            input_ids = torch.randint(0, vocab_size, (1, seq_len))
            with torch.no_grad():
                out = model(input_ids, labels=input_ids)
            losses.append(out.loss.item())
        avg_loss = sum(losses) / len(losses)
        ppl = math.exp(min(avg_loss, 100.0))  # clamp to avoid overflow

        elapsed = time.time() - t0
        result = {
            "name": exp["name"],
            "source": exp["source"],
            "n_directions": exp["n_directions"],
            "regularization": exp["regularization"],
            "norm_preserve": exp["norm_preserve"],
            "direction_recovery": round(avg_cos, 4),
            "subspace_recovery": round(avg_subspace, 4),
            "primary_removal": round(avg_removal, 4),
            "multi_dir_avg_residual": round(avg_multi_residual, 4),
            "layers_correct": correct_selected,
            "layers_false_positive": false_selected,
            "layers_missed": missed,
            "n_layers_selected": len(selected),
            "weight_distance": round(total_dist, 2),
            "perplexity": round(ppl, 2),
            "time_seconds": round(elapsed, 2),
        }
        results.append(result)

        print(f" Direction recovery: {avg_cos:.3f} (cosine sim to ground truth)")
        print(f" Subspace recovery: {avg_subspace:.3f} (planted dirs captured)")
        print(f" Primary dir removal: {avg_removal:.1%} (refusal signal removed)")
        print(f" Multi-dir avg residual: {avg_multi_residual:.3f} (lower = better)")
        print(f" Layer selection: {correct_selected}/{len(target_layers)} correct, "
              f"{false_selected} false+, {missed} missed")
        print(f" Weight distance: {total_dist:.2f} (capability delta)")
        print(f" Perplexity: {ppl:.2f}")

        del model
        gc.collect()

    return results
def print_table(results: list[dict]):
    """Print the comparison tables and the closing analysis to stdout.

    Tables 1 and 2 are built from ``results``; Table 3 and the analysis text
    are fixed content embedded in this function.

    Args:
        results: One dict per technique. Each must provide the keys
            ``name``, ``source``, ``direction_recovery``,
            ``subspace_recovery``, ``primary_removal``,
            ``multi_dir_avg_residual``, ``n_layers_selected``,
            ``layers_correct``, ``layers_false_positive``, ``layers_missed``,
            ``weight_distance`` and ``perplexity``.
    """
    # ── Table 1: Direction Extraction Quality ──────────────────────────
    print(f"\n\n{'='*100}")
    print("TABLE 1: DIRECTION EXTRACTION & REMOVAL QUALITY")
    print(f"{'='*100}")
    print(f"{'Technique':<38} {'Source':<14} {'DirRecov':>9} {'SubRecov':>9} "
          f"{'Removal':>8} {'Residual':>9}")
    # Header underline: one run of dashes per column, matching its width.
    print(f"{'-'*38} {'-'*14} {'-'*9} {'-'*9} {'-'*8} {'-'*9}")
    for r in results:
        # Truncate so long labels cannot break the column alignment.
        name = r["name"][:37]
        source = r["source"][:13]
        dr = f"{r['direction_recovery']:.3f}"
        sr = f"{r['subspace_recovery']:.3f}"
        rm = f"{r['primary_removal']:.1%}"
        res = f"{r['multi_dir_avg_residual']:.3f}"
        print(f"{name:<38} {source:<14} {dr:>9} {sr:>9} {rm:>8} {res:>9}")

    # ── Table 2: Layer Selection & Capability ──────────────────────────
    print(f"\n{'='*100}")
    print("TABLE 2: LAYER SELECTION & CAPABILITY PRESERVATION")
    print(f"{'='*100}")
    print(f"{'Technique':<38} {'Layers':>7} {'Correct':>8} {'FalsePos':>9} "
          f"{'Missed':>7} {'WeightΔ':>8} {'PPL':>8}")
    print(f"{'-'*38} {'-'*7} {'-'*8} {'-'*9} {'-'*7} {'-'*8} {'-'*8}")
    for r in results:
        name = r["name"][:37]
        print(f"{name:<38} {r['n_layers_selected']:>7} {r['layers_correct']:>8} "
              f"{r['layers_false_positive']:>9} {r['layers_missed']:>7} "
              f"{r['weight_distance']:>8.2f} {r['perplexity']:>8.2f}")

    # ── Table 3: Literature Comparison ────────────────────────────────
    print(f"\n\n{'='*110}")
    print("TABLE 3: FULL LANDSCAPE — TECHNIQUES, CAPABILITIES, AND REPORTED RESULTS")
    print(f"{'='*110}")
    print(f"{'Technique':<26} {'Year':>5} {'#Dir':>5} {'Layers':>10} {'NormPres':>9} "
          f"{'Reg':>5} {'AutoTune':>9} {'Reported Refusal→':>18} {'Model':>14}")
    print(f"{'-'*26} {'-'*5} {'-'*5} {'-'*10} {'-'*9} {'-'*5} {'-'*9} {'-'*18} {'-'*14}")
    # Columns: technique, year, #directions, layer selection, norm
    # preservation, regularization, auto-tuning, reported refusal rate, model.
    literature = [
        ("Arditi et al.", "2024", "1", "top-norm", "No", "0.0", "No",
         "~95%→~0%", "Llama-3-8B"),
        ("FailSpy/abliterator", "2024", "1", "mid-60%", "No", "0.0", "No",
         "~90%→~5%", "Llama-3-8B"),
        ("mlabonne tutorial", "2024", "1", "top-norm", "No", "0.0", "No",
         "~90%→~5%", "Llama-3-8B"),
        ("Gabliteration", "2024", "4-8", "knee", "No", "0.0", "No",
         "~95%→~0%", "Various 7B+"),
        ("grimjim norm-pres", "2024", "4-8", "knee", "Yes(bug)", "0.0", "No",
         "~90%→~5%", "Various 7B+"),
        ("Heretic (p-e-w)", "2025", "float", "kernel", "No", "TPE", "Yes",
         "~95%→~0%*", "Gemma-3-12B"),
        ("Wollschlager cones", "2025", "1-5", "per-layer", "", "", "RDO",
         "~98%→~1%", "Llama-3.1-8B"),
        ("OBLITERATUS basic", "2025", "1", "knee", "No", "0.0", "No",
         "~95%→60%**", "Qwen-0.5B"),
        ("OBLITERATUS advanced", "2025", "4", "knee", "Yes(fix)", "0.3", "No",
         "~95%→73%**", "Qwen-0.5B"),
        ("OBLITERATUS surgical", "2025", "8", "knee", "Yes(fix)", "0.0", "Yes***",
         "~95%→0%/broken", "Qwen-0.5B"),
    ]
    for row in literature:
        print(f"{row[0]:<26} {row[1]:>5} {row[2]:>5} {row[3]:>10} {row[4]:>9} "
              f"{row[5]:>5} {row[6]:>9} {row[7]:>18} {row[8]:>14}")
    print("\n * Heretic: 2.8× lower KL divergence than manual abliterations (Gemma-3-12B benchmark)")
    print(" ** Our observed results on Qwen2.5-0.5B-Instruct — 0.5B may be too small for linear methods")
    print(" *** Surgical combines: whitened SVD + SAE + head surgery + neuron masking + jailbreak contrast")
    print(f"{'='*110}")

    # ── Analysis ──────────────────────────────────────────────────────
    print(f"\n{'='*80}")
    print("ANALYSIS: WHY OBLITERATUS UNDERPERFORMS AND WHAT TO FIX")
    print(f"{'='*80}")
    print("""
ROOT CAUSES (ordered by impact):
1. MODEL SIZE: All published abliteration results use 7B+ models
- Arditi et al.: Llama-3-8B, Gemma-2-9B (hidden_dim=4096+)
- FailSpy: Llama-3-8B
- Heretic: Gemma-3-12B (headline benchmark)
- Wollschlager et al.: Llama-3.1-8B
- OBLITERATUS benchmarks: Qwen-0.5B (hidden_dim=896)
The "single refusal direction" hypothesis may not hold well for small
models. Wollschlager et al. (ICML 2025) showed that refusal lives in
multi-dimensional CONCEPT CONES, and cone dimension scales with model
size. A 0.5B model may encode refusal too diffusely for linear methods.
2. BASIC MODE USES NO CHAT TEMPLATE for activation collection
- The model was trained with chat formatting — without it, activations
during probing don't reflect actual refusal behavior
- This is the single highest-impact config fix
3. ADVANCED MODE REGULARIZATION TOO HIGH (0.3)
- Preserves 30% of refusal component by design
- Combined with 4 directions where later ones capture noise, net
removal is weak
4. SURGICAL MODE DOES TOO MUCH
- 8 directions, whitened SVD, SAE features, neuron masking, head surgery
- Each individually reasonable; together they destroy a 0.5B model
- The whitened SVD un-whitening bug (now fixed) was extracting noise
5. NO BAYESIAN OPTIMIZATION (vs Heretic)
- Heretic's key insight: jointly optimize layer weights, direction
index, and component-specific parameters via TPE
- Minimizes refusal rate AND KL divergence simultaneously
- This automatically handles model-specific tuning that we do manually
RECOMMENDED CONFIG CHANGES:
- basic: use_chat_template → True
- advanced: regularization → 0.1 (from 0.3)
- surgical: n_directions → 4 (from 8), disable safety_neuron_masking
- ALL: Add model-size-aware defaults (n_dirs=1 for <2B, 4 for 2-10B)
- NEW: Add TPE optimization loop (like Heretic) as "optimized" method
""")
def main():
    """Run the experiment, print the tables and dump the metrics as JSON."""
    experiment_results = run_experiment()
    print_table(experiment_results)

    # Save results
    destination = "/tmp/abliteration_comparison_results.json"
    serialized = json.dumps(experiment_results, indent=2)
    with open(destination, "w") as sink:
        sink.write(serialized)
    print(f"\nResults saved to {destination}")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()