# Repository-viewer metadata (not part of the module source):
#   OBLITERATUS/tests/test_abliterate.py — 2026-03-05 00:50:44 -08:00
#   2635 lines, 102 KiB, Python
"""Tests for the SOTA abliteration pipeline."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock
import pytest
import torch
from transformers import GPT2Config, GPT2LMHeadModel
from obliteratus.abliterate import (
HARMFUL_PROMPTS,
HARMLESS_PROMPTS,
METHODS,
STAGES,
AbliterationPipeline,
PipelineStage,
StageResult,
)
from obliteratus.models.loader import ModelHandle
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_tiny_handle():
    """Build a ModelHandle around a tiny random GPT-2 for fast unit tests.

    The tokenizer is a MagicMock that always returns a fixed-shape batch and
    decodes to a benign sentence, so no real tokenizer assets are needed.
    """
    cfg = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=2,
        n_inner=256,
    )
    tiny_model = GPT2LMHeadModel(cfg)
    tiny_model.eval()
    tok = MagicMock()
    tok.pad_token = "<pad>"
    tok.eos_token = "<eos>"
    tok.return_value = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    tok.decode.return_value = "The capital of France is Paris, a beautiful city"
    tiny_handle = ModelHandle(
        model=tiny_model,
        tokenizer=tok,
        config=cfg,
        model_name="gpt2-test",
        task="causal_lm",
    )
    # Snapshot the pristine weights so tests can restore/compare later.
    tiny_handle.snapshot()
    return tiny_handle
def _make_varied_tokenizer(handle):
    """Make the handle's mock tokenizer emit different tokens on every call.

    Each invocation reseeds torch's RNG with the call index, so outputs are
    distinct across calls yet fully reproducible.
    """
    n_calls = [0]

    def _tokenize(prompt, **kwargs):
        n_calls[0] += 1
        torch.manual_seed(n_calls[0])
        ids = torch.randint(0, 1000, (1, 5))
        mask = torch.ones(1, 5, dtype=torch.long)
        return {"input_ids": ids, "attention_mask": mask}

    handle.tokenizer.side_effect = _tokenize
@pytest.fixture
def handle():
    """Provide a fresh tiny-GPT-2 ModelHandle per test."""
    return _make_tiny_handle()
# ---------------------------------------------------------------------------
# Data & stage definitions
# ---------------------------------------------------------------------------
class TestPrompts:
    """Sanity checks on the built-in contrastive prompt datasets."""

    def test_harmful_prompts_expanded(self):
        assert len(HARMFUL_PROMPTS) >= 99

    def test_harmless_prompts_expanded(self):
        assert len(HARMLESS_PROMPTS) >= 99

    def test_prompt_lists_same_length(self):
        harmful_count = len(HARMFUL_PROMPTS)
        harmless_count = len(HARMLESS_PROMPTS)
        assert harmful_count == harmless_count

    def test_prompt_count_512(self):
        """512 prompts across 7 severity tiers."""
        for prompt_list in (HARMFUL_PROMPTS, HARMLESS_PROMPTS):
            assert len(prompt_list) == 512

    def test_prompt_volume_slicing(self):
        """Slicing at standard volumes gives correct counts."""
        for volume in (33, 66, 99, 256, 512):
            assert len(HARMFUL_PROMPTS[:volume]) == volume
            assert len(HARMLESS_PROMPTS[:volume]) == volume
class TestStages:
    """Pipeline stage definitions and the StageResult dataclass."""

    def test_six_stages(self):
        assert len(STAGES) == 6

    def test_stage_keys(self):
        expected_order = ["summon", "probe", "distill", "excise", "verify", "rebirth"]
        assert [stage.key for stage in STAGES] == expected_order

    def test_stage_dataclass(self):
        st = PipelineStage(key="test", name="TEST", description="A test stage")
        assert st.key == "test"
        assert st.name == "TEST"

    def test_stage_result_defaults(self):
        res = StageResult(stage="test", status="running")
        assert res.message == ""
        assert res.duration == 0.0
        assert res.details == {}
# ---------------------------------------------------------------------------
# Method presets
# ---------------------------------------------------------------------------
class TestMethods:
    """Checks on the METHODS preset dictionary."""

    def test_methods_exist(self):
        expected = {
            "basic", "advanced", "aggressive", "informed", "surgical",
            "inverted", "nuclear", "optimized", "failspy", "gabliteration",
            "heretic", "rdo", "spectral_cascade",
        }
        assert set(METHODS.keys()) == expected

    def test_basic_single_direction(self):
        basic = METHODS["basic"]
        assert basic["n_directions"] == 1
        assert basic["norm_preserve"] is False
        assert basic["regularization"] == 0.0
        assert basic["refinement_passes"] == 1

    def test_advanced_multi_direction(self):
        adv = METHODS["advanced"]
        assert adv["n_directions"] > 1
        assert adv["norm_preserve"] is True
        assert adv["regularization"] > 0
        assert adv["refinement_passes"] >= 2

    def test_aggressive_full_gabliteration(self):
        agg = METHODS["aggressive"]
        assert agg["n_directions"] >= 8
        assert agg["norm_preserve"] is True
        assert agg["refinement_passes"] >= 3
# ---------------------------------------------------------------------------
# Pipeline init
# ---------------------------------------------------------------------------
class TestPipelineInit:
    """Constructor defaults, method presets, and callback wiring."""

    def test_default_prompts(self):
        p = AbliterationPipeline(model_name="test-model")
        assert p.harmful_prompts == HARMFUL_PROMPTS
        assert p.harmless_prompts == HARMLESS_PROMPTS

    def test_custom_prompts(self):
        bad = ["bad prompt"]
        good = ["good prompt"]
        p = AbliterationPipeline(
            model_name="test-model",
            harmful_prompts=bad,
            harmless_prompts=good,
        )
        assert p.harmful_prompts == bad
        assert p.harmless_prompts == good

    def test_defaults(self):
        p = AbliterationPipeline(model_name="test-model")
        assert p.device == "auto"
        assert p.dtype == "float16"
        assert p.output_dir == Path("abliterated")
        assert p.trust_remote_code is False
        assert p.handle is None

    def test_default_method_is_advanced(self):
        p = AbliterationPipeline(model_name="test-model")
        advanced = METHODS["advanced"]
        assert p.method == "advanced"
        assert p.n_directions == advanced["n_directions"]
        assert p.norm_preserve == advanced["norm_preserve"]
        assert p.regularization == advanced["regularization"]

    def test_method_basic(self):
        p = AbliterationPipeline(model_name="test-model", method="basic")
        assert p.n_directions == 1
        assert p.norm_preserve is False
        assert p.regularization == 0.0

    def test_method_aggressive(self):
        p = AbliterationPipeline(model_name="test-model", method="aggressive")
        assert p.n_directions == 8
        assert p.norm_preserve is True
        assert p.refinement_passes == 3

    def test_explicit_overrides_method(self):
        # Explicit kwargs must win over the values the preset would supply.
        p = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            n_directions=6,
            norm_preserve=True,
            regularization=0.5,
            refinement_passes=4,
        )
        assert p.n_directions == 6
        assert p.norm_preserve is True
        assert p.regularization == 0.5
        assert p.refinement_passes == 4

    def test_callbacks(self):
        seen_results = []
        seen_logs = []
        p = AbliterationPipeline(
            model_name="test-model",
            on_stage=seen_results.append,
            on_log=seen_logs.append,
        )
        p.log("hello")
        assert seen_logs == ["hello"]
        p._emit("test", "running", "msg")
        assert len(seen_results) == 1
        assert seen_results[0].stage == "test"
# ---------------------------------------------------------------------------
# _project_out_advanced (norm-preserving + regularization)
# ---------------------------------------------------------------------------
class TestProjectOutAdvanced:
    """Tests for ``_project_out_advanced`` (norm preservation + regularization).

    Fix: the norm-preserving test previously computed an unused local
    (``without_preserve_norm_sq``) — and computed it from the *already
    projected* weights, so the value was meaningless anyway. The dead
    computation has been removed; the assertion is unchanged.
    """

    def test_norm_preserving(self):
        """Norm-preserving mode should keep Frobenius norm approximately constant."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        module.o_proj.weight.data = torch.randn(8, 4)
        original_norm = module.o_proj.weight.data.norm().item()
        direction = torch.randn(4, 1)
        direction = direction / direction.norm()
        AbliterationPipeline._project_out_advanced(
            module, direction, ["o_proj"], norm_preserve=True, regularization=0.0
        )
        new_norm = module.o_proj.weight.data.norm().item()
        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices (hidden_dim=4) where a single direction
        # removes a large fraction of energy. Verify the cap restored most of
        # the removed energy, keeping the norm close to the original.
        assert new_norm >= original_norm * 0.85, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"

    def test_regularization_partial_removal(self):
        """Regularization should preserve some of the refusal component."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module_full = Wrapper()
        module_reg = Wrapper()
        torch.manual_seed(42)
        W_orig = torch.randn(8, 4)
        module_full.o_proj.weight.data = W_orig.clone()
        module_reg.o_proj.weight.data = W_orig.clone()
        direction = torch.randn(4, 1)
        direction = direction / direction.norm()
        # Full removal
        AbliterationPipeline._project_out_advanced(
            module_full, direction, ["o_proj"], norm_preserve=False, regularization=0.0
        )
        # Regularized (30% preserved)
        AbliterationPipeline._project_out_advanced(
            module_reg, direction, ["o_proj"], norm_preserve=False, regularization=0.3
        )
        W_full = module_full.o_proj.weight.data
        W_reg = module_reg.o_proj.weight.data
        # Full removal should have zero projection on direction
        proj_full = (W_full @ direction).norm().item()
        assert proj_full < 1e-4
        # Regularized should have non-zero projection (30% preserved)
        proj_reg = (W_reg @ direction).norm().item()
        proj_orig = (W_orig @ direction).norm().item()
        expected_ratio = 0.3
        actual_ratio = proj_reg / proj_orig if proj_orig > 0 else 0
        assert abs(actual_ratio - expected_ratio) < 0.05, \
            f"Expected ~{expected_ratio:.0%} preserved, got {actual_ratio:.0%}"

    def test_norm_preserving_transposed(self):
        """Norm-preserving should also work for transposed weights."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c_proj = torch.nn.Linear(8, 4, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        module.c_proj.weight.data = torch.randn(4, 8)
        original_norm = module.c_proj.weight.data.norm().item()
        direction = torch.randn(4, 1)
        direction = direction / direction.norm()
        AbliterationPipeline._project_out_advanced(
            module, direction, ["c_proj"], norm_preserve=True, regularization=0.0
        )
        new_norm = module.c_proj.weight.data.norm().item()
        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices where a single direction removes a large
        # fraction of energy.
        assert new_norm >= original_norm * 0.80, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"
# ---------------------------------------------------------------------------
# Full attention projection (q/k/v + o_proj)
# ---------------------------------------------------------------------------
class TestAttentionFullProjection:
    """Test that ALL attention weight matrices are projected (not just o_proj)."""

    def test_qkv_all_projected(self):
        """q_proj, k_proj, v_proj should all be projected alongside o_proj."""
        hidden = 16

        class FakeAttn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.q_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.k_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.v_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.o_proj = torch.nn.Linear(hidden, hidden, bias=False)

        attn = FakeAttn()
        torch.manual_seed(42)
        for param in attn.parameters():
            param.data = torch.randn_like(param.data)
        names = ["q_proj", "k_proj", "v_proj", "o_proj"]
        before = {n: getattr(attn, n).weight.data.clone() for n in names}
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        from obliteratus.abliterate import _ATTN_OUT_NAMES, _ATTN_IN_NAMES
        n_hit = AbliterationPipeline._project_out_advanced(
            attn, direction, _ATTN_OUT_NAMES + _ATTN_IN_NAMES,
        )
        assert n_hit == 4, f"Should project 4 weights (q/k/v/o), got {n_hit}"
        for n in names:
            assert not torch.allclose(
                getattr(attn, n).weight.data, before[n]
            ), f"{n} should be modified"

    def test_project_all_does_not_early_return(self):
        """_project_out_advanced should project ALL matching weights, not just first."""
        hidden = 16

        class FakeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        ffn = FakeModule()
        torch.manual_seed(42)
        before_up = ffn.up_proj.weight.data.clone()
        before_gate = ffn.gate_proj.weight.data.clone()
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        from obliteratus.abliterate import _FFN_IN_NAMES
        n_hit = AbliterationPipeline._project_out_advanced(ffn, direction, _FFN_IN_NAMES)
        assert n_hit == 2, f"Should project both up_proj and gate_proj, got {n_hit}"
        assert not torch.allclose(ffn.up_proj.weight.data, before_up), "up_proj should be modified"
        assert not torch.allclose(ffn.gate_proj.weight.data, before_gate), "gate_proj should be modified"

    def test_lm_head_projection(self):
        """lm_head should be projectable via _project_out_advanced."""
        hidden = 16
        vocab = 100

        class FakeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

        lm = FakeModel()
        torch.manual_seed(42)
        before = lm.lm_head.weight.data.clone()
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        n_hit = AbliterationPipeline._project_out_advanced(
            lm, direction, ["lm_head"], regularization=0.0,
        )
        assert n_hit == 1, "Should project lm_head"
        assert not torch.allclose(lm.lm_head.weight.data, before), "lm_head should be modified"
        # Verify refusal direction is removed from lm_head
        residual = (lm.lm_head.weight.data @ direction).norm().item()
        assert residual < 1e-4, f"Refusal direction should be removed from lm_head, proj={residual}"
class TestKneeDetectionThreshold:
    """Test that knee detection uses 5% threshold to include more layers."""

    def test_five_percent_threshold_includes_more(self):
        """Layers between 5% and 10% of max should now be included."""
        # Layer norms: max=10.0, then several between 5%-10% of that max.
        ranked = [(0, 10.0), (1, 8.0), (2, 6.0), (3, 0.7), (4, 0.6)]
        chosen = AbliterationPipeline._select_layers_knee(ranked)
        # 0.7 and 0.6 are 7% and 6% of max — should now be included (> 5% threshold)
        assert 3 in chosen or 4 in chosen, (
            f"Layers with 6-7% of max signal should be included, got {chosen}"
        )
# ---------------------------------------------------------------------------
# MoE projection (router, shared expert, input/output, fused)
# ---------------------------------------------------------------------------
class TestProjectMoEExperts:
    """Test the full MoE projection pipeline: router, shared expert, experts."""

    def _make_direction(self, hidden_dim=16):
        # Helper: unit-norm column vector used as the refusal direction.
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    def test_router_gate_projected(self):
        """Router/gate weight should have refusal direction removed."""
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                # Bare Module with Linear attributes, mimicking an expert MLP.
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        W_gate_orig = moe.gate.weight.data.clone()
        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0
        # Gate weight should have been modified
        assert not torch.allclose(moe.gate.weight.data, W_gate_orig), \
            "Router/gate weights should be projected"
        # The gate weight's projection onto the direction should be ~0
        proj = (moe.gate.weight.data @ d).norm().item()
        assert proj < 1e-4, f"Gate should have no component along refusal dir, got {proj}"

    def test_shared_expert_projected(self):
        """Shared expert (always-on) should have both input and output projected."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 2, bias=False)
                # Qwen-style shared expert that processes every token.
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(2)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        shared_down_orig = moe.shared_expert.down_proj.weight.data.clone()
        shared_up_orig = moe.shared_expert.up_proj.weight.data.clone()
        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0
        # Both shared expert output AND input projections should be modified
        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, shared_down_orig), \
            "Shared expert output (down_proj) should be projected"
        assert not torch.allclose(moe.shared_expert.up_proj.weight.data, shared_up_orig), \
            "Shared expert input (up_proj) should be projected"

    def test_expert_input_projections_projected(self):
        """Expert input projections (up_proj, gate_proj) should also be modified."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(2)])

        moe = FakeMoE()
        d = self._make_direction(hidden)
        up_orig = moe.experts[0].up_proj.weight.data.clone()
        count = AbliterationPipeline._project_moe_experts(moe, d)
        # Each expert contributes 2 projections (output + input)
        # 2 experts * 2 = 4 minimum
        assert count >= 4, f"Expected >= 4 projections (out+in per expert), got {count}"
        assert not torch.allclose(moe.experts[0].up_proj.weight.data, up_orig), \
            "Expert input (up_proj) should be projected"

    def test_fused_3d_output_and_input(self):
        """Fused 3D parameter patterns (GPT-OSS style) should project both directions."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # All experts stacked into a single 3D parameter.
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))
                self.up_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)
        down_orig = moe.experts.down_proj.data.clone()
        up_orig = moe.experts.up_proj.data.clone()
        count = AbliterationPipeline._project_moe_experts(moe, d)
        # 4 experts output + 4 experts input = 8
        assert count == 8, f"Expected 8 fused projections, got {count}"
        assert not torch.allclose(moe.experts.down_proj.data, down_orig), \
            "Fused output (down_proj) should be projected"
        assert not torch.allclose(moe.experts.up_proj.data, up_orig), \
            "Fused input (up_proj) should be projected"

    def test_fused_3d_norm_preserve(self):
        """Fused 3D projections should preserve norms when requested."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)
        # Record per-expert norms before
        orig_norms = [moe.experts.down_proj.data[i].norm().item() for i in range(n_experts)]
        AbliterationPipeline._project_moe_experts(moe, d, norm_preserve=True)
        # Check per-expert norms preserved
        for i in range(n_experts):
            new_norm = moe.experts.down_proj.data[i].norm().item()
            assert abs(orig_norms[i] - new_norm) < 1e-3, \
                f"Expert {i} norm not preserved: {orig_norms[i]:.4f} vs {new_norm:.4f}"

    def test_no_experts_returns_zero(self):
        """Module without experts attribute should return 0."""
        class NoMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = torch.nn.Linear(16, 32)

        moe = NoMoE()
        d = self._make_direction(16)
        assert AbliterationPipeline._project_moe_experts(moe, d) == 0

    def test_router_bias_projected(self):
        """Router bias should be projected when project_biases=True."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(4)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        bias_orig = moe.gate.bias.data.clone()
        count = AbliterationPipeline._project_moe_experts(moe, d, project_biases=True)
        # Gate has 4 outputs (num_experts), direction has 16 dims
        # bias shape (4,) != direction shape (16,), so bias won't match.
        # This is correct: router bias is (num_experts,), not (hidden_dim,),
        # so _project_bias won't modify it (shape mismatch is expected).
        assert torch.allclose(moe.gate.bias.data, bias_orig), (
            "Router bias should be unchanged when shape mismatches direction"
        )
        assert isinstance(count, int)
        assert count > 0  # expert weights should still be projected

    def test_router_auto_detection_fallback(self):
        """Unknown router name should be auto-detected and projected."""
        import warnings as w
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Unusual router name not in _ROUTER_NAMES
                self.moe_gate_proj = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        gate_orig = moe.moe_gate_proj.weight.data.clone()
        with w.catch_warnings(record=True) as caught:
            w.simplefilter("always")
            AbliterationPipeline._project_moe_experts(moe, d)
        # Should auto-detect and project the unusual router name
        assert not torch.allclose(moe.moe_gate_proj.weight.data, gate_orig), \
            "Auto-detected router should be projected"
        # Should emit a warning about the auto-detection
        auto_detect_warnings = [
            x for x in caught
            if "auto-detected" in str(x.message)
        ]
        assert len(auto_detect_warnings) > 0, "Should warn about auto-detected router"

    def test_full_moe_all_components(self):
        """End-to-end: all MoE components should be modified together."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(4)])

        moe = FakeMoE()
        d = self._make_direction(hidden)
        count = AbliterationPipeline._project_moe_experts(moe, d)
        # Expected: 1 (gate) + 2 (shared out+in) + 4*2 (expert out+in) = 11
        assert count == 11, f"Expected 11 total projections, got {count}"
# ---------------------------------------------------------------------------
# SOTA technique #1: Safety-neuron masking (GateBreaker-style z-score)
# ---------------------------------------------------------------------------
class TestSafetyNeuronMasking:
    """GateBreaker-style z-score masking of safety neurons."""

    def test_outlier_neurons_zeroed(self):
        """Neurons with outsized refusal projection should be zeroed."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        wrapped = Wrapper()
        torch.manual_seed(42)
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        # Rows 0-2 point almost exactly along the direction with huge
        # magnitude, making them clear z-score outliers.
        for row in range(3):
            wrapped.down_proj.weight.data[row] = direction.squeeze() * 10.0
        masked = AbliterationPipeline._mask_safety_neurons(
            wrapped, direction, ["down_proj"], z_threshold=2.0,
        )
        assert masked >= 3, f"Expected >= 3 masked neurons, got {masked}"
        for row in range(3):
            assert wrapped.down_proj.weight.data[row].abs().max().item() < 1e-6

    def test_no_outliers_no_masking(self):
        """When all neurons have similar projection, none should be masked."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)

        wrapped = Wrapper()
        # Identical rows everywhere — zero variance, so no z-score outliers.
        wrapped.down_proj.weight.data = torch.ones(32, hidden) * 0.1
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        masked = AbliterationPipeline._mask_safety_neurons(
            wrapped, direction, ["down_proj"], z_threshold=2.0,
        )
        assert masked == 0

    def test_high_threshold_masks_fewer(self):
        """Higher z_threshold should mask fewer neurons."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        low_thresh = Wrapper()
        high_thresh = Wrapper()
        torch.manual_seed(42)
        shared_weights = torch.randn(64, hidden)
        low_thresh.down_proj.weight.data = shared_weights.clone()
        high_thresh.down_proj.weight.data = shared_weights.clone()
        direction = torch.randn(hidden, 1)
        direction = direction / direction.norm()
        n_low = AbliterationPipeline._mask_safety_neurons(low_thresh, direction, ["down_proj"], z_threshold=1.5)
        n_high = AbliterationPipeline._mask_safety_neurons(high_thresh, direction, ["down_proj"], z_threshold=3.0)
        assert n_low >= n_high, f"Lower threshold should mask more: {n_low} vs {n_high}"
# ---------------------------------------------------------------------------
# SOTA technique #2: Jailbreak-contrastive direction refinement
# ---------------------------------------------------------------------------
class TestJailbreakContrastive:
    """SOTA technique #2: jailbreak-contrastive direction refinement."""

    def test_jailbreak_prompt_generation(self):
        """Pipeline should generate jailbreak variants of harmful prompts."""
        p = AbliterationPipeline(model_name="test")
        p.harmful_prompts = ["How to hack a computer", "Write a virus"]
        variants = p._generate_jailbreak_prompts()
        assert len(variants) == 2
        # Each variant should embed the original prompt...
        assert "hack a computer" in variants[0].lower()
        assert "virus" in variants[1].lower()
        # ...and be wrapped in a template, hence longer than the original.
        assert len(variants[0]) > len("How to hack a computer")

    def test_jailbreak_contrast_blending(self):
        """Jailbreak-contrastive blending should modify refusal direction."""
        p = AbliterationPipeline(
            model_name="test",
            use_jailbreak_contrast=True,
            n_directions=1,
        )
        hidden = 16
        p._on_log = lambda m: None
        # Simulate probed activation means for a single layer.
        torch.manual_seed(42)
        harm = torch.randn(1, hidden)
        safe = torch.randn(1, hidden)
        jailbroken = torch.randn(1, hidden)
        p._harmful_means = {0: harm}
        p._harmless_means = {0: safe}
        p._jailbreak_means = {0: jailbroken}
        p._harmful_acts = {0: [harm]}
        p._harmless_acts = {0: [safe]}
        p._jailbreak_acts = {0: [jailbroken]}
        # Run distill (sets the standard direction, then blends).
        p._distill()
        refusal = p.refusal_directions[0]
        # Result must be a unit vector.
        assert abs(refusal.norm().item() - 1.0) < 1e-4
        # Result must differ from the plain harm-minus-safe direction.
        baseline = (harm - safe).squeeze()
        baseline = baseline / baseline.norm()
        cos = (refusal @ baseline).item()
        assert cos < 0.99, f"Blended direction too similar to standard: cos={cos}"

    def test_surgical_method_enables_jailbreak(self):
        """Surgical method should enable jailbreak-contrastive by default."""
        assert METHODS["surgical"]["use_jailbreak_contrast"] is True
# ---------------------------------------------------------------------------
# SOTA technique #3: Layer-adaptive projection strength
# ---------------------------------------------------------------------------
class TestLayerAdaptiveStrength:
    """SOTA technique #3: layer-adaptive projection strength."""

    def test_layer_weights_computed(self):
        """Layer-adaptive weights should be proportional to refusal signal."""
        p = AbliterationPipeline(
            model_name="test",
            layer_adaptive_strength=True,
            n_directions=1,
        )
        hidden = 16
        p._on_log = lambda m: None
        # Layer 0 carries a strong refusal signal, layer 1 a weak one.
        torch.manual_seed(42)
        strong = torch.randn(1, hidden) * 10.0
        weak = torch.randn(1, hidden) * 1.0
        baseline = torch.zeros(1, hidden)
        p._harmful_means = {0: strong, 1: weak}
        p._harmless_means = {0: baseline, 1: baseline}
        p._harmful_acts = {0: [strong], 1: [weak]}
        p._harmless_acts = {0: [baseline], 1: [baseline]}
        p._distill()
        # Strong layers must receive non-empty excise weights.
        assert len(p._layer_excise_weights) > 0
        # The strongest layer should be normalized to ~1.0.
        top = max(p._layer_excise_weights.values())
        assert top > 0.9, f"Max weight should be ~1.0, got {top}"

    def test_surgical_method_enables_adaptive(self):
        """Surgical method should enable layer-adaptive by default."""
        assert METHODS["surgical"]["layer_adaptive_strength"] is True
# ---------------------------------------------------------------------------
# SOTA technique #5: Attention head surgery
# ---------------------------------------------------------------------------
class TestAttentionHeadSurgery:
def test_head_selective_projection(self):
"""Selective head projection should only modify targeted head rows."""
hidden = 16
n_heads = 4
head_dim = hidden // n_heads
class FakeAttn(torch.nn.Module):
def __init__(self):
super().__init__()
self.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
attn = FakeAttn()
torch.manual_seed(42)
W_orig = attn.o_proj.weight.data.clone()
d = torch.randn(hidden, 1)
d = d / d.norm()
# Head scores: head 0 is top safety head, head 3 is lowest
head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]
n_modified = AbliterationPipeline._project_head_selective(
attn, d, head_scores, n_heads=n_heads, head_fraction=0.25,
)
assert n_modified >= 1, "Should modify at least 1 head"
W_new = attn.o_proj.weight.data
# Head 0 columns (targeted) should be modified
assert not torch.allclose(
W_new[:, 0:head_dim], W_orig[:, 0:head_dim]
), "Targeted head 0 should be modified"
# Head 3 columns (NOT targeted) should be untouched
assert torch.allclose(
W_new[:, 3*head_dim:4*head_dim],
W_orig[:, 3*head_dim:4*head_dim],
), "Non-targeted head 3 should be untouched"
def test_head_surgery_norm_preserve(self):
"""Head surgery with norm_preserve should maintain per-head norms."""
hidden = 16
n_heads = 4
head_dim = hidden // n_heads
class FakeAttn(torch.nn.Module):
def __init__(self):
super().__init__()
self.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
attn = FakeAttn()
torch.manual_seed(42)
d = torch.randn(hidden, 1)
d = d / d.norm()
orig_norms = [
attn.o_proj.weight.data[:, h*head_dim:(h+1)*head_dim].norm().item()
for h in range(n_heads)
]
head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]
AbliterationPipeline._project_head_selective(
attn, d, head_scores, n_heads=n_heads,
head_fraction=0.5, norm_preserve=True,
)
# Targeted heads should have preserved norms
for h in range(2): # top 50% = 2 heads
new_norm = attn.o_proj.weight.data[:, h*head_dim:(h+1)*head_dim].norm().item()
assert abs(orig_norms[h] - new_norm) < 1e-3, \
f"Head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"
def test_head_surgery_non_square_gqa(self):
"""Head surgery should work for GQA models with non-square o_proj (attn_dim != hidden_dim)."""
hidden_dim = 12 # model hidden dimension
attn_dim = 32 # attention dimension (n_heads * head_dim_attn)
n_heads = 4
head_dim_attn = attn_dim // n_heads # 8
class FakeAttnGQA(torch.nn.Module):
def __init__(self):
super().__init__()
# o_proj maps attn_dim -> hidden_dim
# nn.Linear weight shape: (hidden_dim, attn_dim) = (12, 32)
self.o_proj = torch.nn.Linear(attn_dim, hidden_dim, bias=False)
attn = FakeAttnGQA()
torch.manual_seed(42)
attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)
W_orig = attn.o_proj.weight.data.clone()
d = torch.randn(hidden_dim, 1)
d = d / d.norm()
head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]
n_modified = AbliterationPipeline._project_head_selective(
attn, d, head_scores, n_heads=n_heads, head_fraction=0.25,
)
assert n_modified >= 1, "Should modify at least 1 head"
W_new = attn.o_proj.weight.data
# Head 0 columns (targeted) should be modified
assert not torch.allclose(
W_new[:, 0:head_dim_attn], W_orig[:, 0:head_dim_attn]
), "Targeted head 0 should be modified"
# Head 3 columns (NOT targeted) should be untouched
assert torch.allclose(
W_new[:, 3*head_dim_attn:4*head_dim_attn],
W_orig[:, 3*head_dim_attn:4*head_dim_attn],
), "Non-targeted head 3 should be untouched"
def test_head_surgery_gqa_norm_preserve(self):
    """Per-head norm preservation must hold even for a rectangular (GQA) o_proj."""
    hidden_dim = 12
    attn_dim = 32
    n_heads = 4
    head_dim_attn = attn_dim // n_heads

    class FakeAttnGQA(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.o_proj = torch.nn.Linear(attn_dim, hidden_dim, bias=False)

    attn = FakeAttnGQA()
    torch.manual_seed(42)
    attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)
    d = torch.randn(hidden_dim, 1)
    d = d / d.norm()

    def head_cols(h):
        # Column slice of o_proj belonging to head h
        return attn.o_proj.weight.data[:, h * head_dim_attn:(h + 1) * head_dim_attn]

    orig_norms = [head_cols(h).norm().item() for h in range(n_heads)]
    AbliterationPipeline._project_head_selective(
        attn, d, [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)],
        n_heads=n_heads, head_fraction=0.5, norm_preserve=True,
    )
    # head_fraction=0.5 targets the two highest-scoring heads (0 and 1)
    for h in range(2):
        new_norm = head_cols(h).norm().item()
        assert abs(orig_norms[h] - new_norm) < 1e-3, \
            f"GQA head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"
# ---------------------------------------------------------------------------
# SOTA technique #6: SAE feature-level abliteration
# ---------------------------------------------------------------------------
class TestSAEAbliteration:
    """SOTA technique #6: feature-level abliteration via a sparse autoencoder."""

    def test_sae_train_and_reconstruct(self):
        """A trained SAE reconstructs activations and emits sparse latent codes."""
        from obliteratus.analysis.sae_abliteration import train_sae

        dim = 32
        torch.manual_seed(42)
        activations = [torch.randn(dim) for _ in range(64)]
        sae = train_sae(activations, dim, expansion=2, n_epochs=10, lr=1e-3)
        sample = torch.randn(1, dim)
        recon, codes = sae(sample)
        assert recon.shape == sample.shape
        # expansion=2 doubles the latent width
        assert codes.shape == (1, 2 * dim)
        # ReLU latents should be mostly zero
        assert (codes == 0).float().mean() > 0.3, "Features should be sparse"

    def test_refusal_feature_identification(self):
        """Features separating harmful from harmless activations are recoverable."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 32
        torch.manual_seed(42)
        refusal_dir = torch.randn(hidden)
        refusal_dir = refusal_dir / refusal_dir.norm()
        # Shift the two populations in opposite senses along refusal_dir
        harmful_acts = [torch.randn(hidden) + 2.0 * refusal_dir for _ in range(32)]
        harmless_acts = [torch.randn(hidden) - 2.0 * refusal_dir for _ in range(32)]
        sae = train_sae(
            harmful_acts + harmless_acts, hidden, expansion=2, n_epochs=30, lr=3e-4,
        )
        result = identify_refusal_features(
            sae, harmful_acts, harmless_acts, layer_idx=0, top_k=4,
        )
        assert result.n_refusal_features == 4
        assert result.sae_directions.shape == (4, hidden)
        assert result.variance_explained > 0.0
        # At least one recovered direction should align with the planted one
        best_cos = max(
            abs((result.sae_directions[i] @ refusal_dir).item())
            for i in range(result.sae_directions.shape[0])
        )
        assert best_cos > 0.1, f"SAE should find direction aligned with refusal: best_cos={best_cos}"

    def test_sae_directions_unit_norm(self):
        """Every direction returned by identify_refusal_features is unit length."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 16
        torch.manual_seed(42)
        offset = torch.ones(hidden)
        harmful = [torch.randn(hidden) + offset for _ in range(16)]
        harmless = [torch.randn(hidden) - offset for _ in range(16)]
        sae = train_sae(harmful + harmless, hidden, expansion=2, n_epochs=10)
        result = identify_refusal_features(sae, harmful, harmless, 0, top_k=3)
        for i in range(result.sae_directions.shape[0]):
            norm = result.sae_directions[i].norm().item()
            assert abs(norm - 1.0) < 1e-3, f"Direction {i} norm={norm}, expected 1.0"
# ---------------------------------------------------------------------------
# Surgical method preset
# ---------------------------------------------------------------------------
class TestSurgicalMethod:
    """The 'surgical' preset bundles all six SOTA techniques."""

    # Flag names shared by the preset config and pipeline attributes.
    _SOTA_FLAGS = (
        "use_jailbreak_contrast",
        "layer_adaptive_strength",
        "safety_neuron_masking",
        "per_expert_directions",
        "attention_head_surgery",
        "use_sae_features",
    )

    def test_surgical_enables_all_sota(self):
        """Surgical method should enable all 6 SOTA techniques."""
        cfg = METHODS["surgical"]
        for flag in self._SOTA_FLAGS:
            assert cfg[flag] is True

    def test_basic_disables_all_sota(self):
        """Basic method should not enable SOTA techniques (no keys or False)."""
        cfg = METHODS["basic"]
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
        ):
            assert cfg.get(flag, False) is False

    def test_pipeline_init_surgical(self):
        """Pipeline initialized with surgical method should have all flags set."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        for flag in self._SOTA_FLAGS:
            assert getattr(pipeline, flag) is True

    def test_pipeline_init_explicit_override(self):
        """Explicit params should override method defaults."""
        pipeline = AbliterationPipeline(
            model_name="test", method="surgical",
            safety_neuron_masking=False,
        )
        assert pipeline.safety_neuron_masking is False
        # Remaining flags still come from the surgical preset
        assert pipeline.use_jailbreak_contrast is True
# ---------------------------------------------------------------------------
# Inverted method (semantic refusal inversion)
# ---------------------------------------------------------------------------
class TestInvertedMethod:
    """Semantic refusal inversion ('inverted' preset).

    Instead of merely removing the refusal direction, the inverted method
    reflects it: a 2x projection subtraction negates the component along
    the direction while leaving the orthogonal complement intact.
    """

    def test_inverted_preset_config(self):
        """Inverted method preset should enable inversion flag."""
        cfg = METHODS["inverted"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 8
        assert cfg["use_jailbreak_contrast"] is True

    def test_surgical_does_not_invert(self):
        """Surgical method should NOT enable inversion by default."""
        cfg = METHODS["surgical"]
        assert cfg.get("invert_refusal", False) is False

    def test_pipeline_init_inverted(self):
        """Pipeline initialized with inverted method should have flag set."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.invert_refusal is True
        assert pipeline.use_jailbreak_contrast is True
        assert pipeline.safety_neuron_masking is False  # zeroing + reflection is destructive

    def test_pipeline_invert_explicit_override(self):
        """Explicit invert_refusal param should override method default."""
        pipeline = AbliterationPipeline(
            model_name="test", method="surgical", invert_refusal=True,
        )
        assert pipeline.invert_refusal is True
        pipeline2 = AbliterationPipeline(
            model_name="test", method="inverted", invert_refusal=False,
        )
        assert pipeline2.invert_refusal is False

    def test_reflection_math(self):
        """2x projection (reflection) should negate the refusal component."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 32, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        W_orig = module.o_proj.weight.data.clone()
        d = torch.randn(hidden, 1)
        d = d / d.norm()
        # Original projection onto d
        orig_proj = (W_orig @ d).squeeze()
        # Reflection: regularization=-1.0 → scale=2.0
        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )
        W_reflected = module.o_proj.weight.data
        new_proj = (W_reflected @ d).squeeze()
        # After reflection, projection should be NEGATED (sign flipped)
        assert torch.allclose(new_proj, -orig_proj, atol=1e-4), (
            f"Reflected projection should be negated: expected ~{-orig_proj[:3]} got {new_proj[:3]}"
        )

    def test_reflection_preserves_orthogonal_component(self):
        """Reflection should not change the component perpendicular to d."""
        hidden = 8

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 16, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        W_orig = module.o_proj.weight.data.clone()
        d = torch.randn(hidden, 1)
        d = d / d.norm()
        # Compute original orthogonal component
        orig_d_component = (W_orig @ d) @ d.T  # rank-1 matrix: projection onto d
        orig_ortho = W_orig - orig_d_component  # everything except d-component
        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )
        W_reflected = module.o_proj.weight.data
        new_d_component = (W_reflected @ d) @ d.T
        new_ortho = W_reflected - new_d_component
        # Orthogonal component should be unchanged
        assert torch.allclose(orig_ortho, new_ortho, atol=1e-4), (
            "Reflection should preserve orthogonal component"
        )

    def test_moe_expert_safety_classification(self):
        """_identify_safety_experts should classify experts by router affinity."""
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    torch.nn.Linear(hidden, hidden) for _ in range(n_experts)
                ])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        from obliteratus.models.loader import ModelHandle
        from unittest.mock import MagicMock
        from transformers import GPT2Config
        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(
            model=model, tokenizer=MagicMock(),
            config=config, model_name="test", task="causal_lm",
        )
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        # Set up fake layer and direction
        layer = FakeLayer()
        torch.manual_seed(42)
        # Make router weight so expert 0 has highest affinity for d
        d = torch.randn(hidden)
        d = d / d.norm()
        # Set router weights: expert 0 aligned with d, expert 3 anti-aligned
        layer.mlp.gate.weight.data[0] = d * 5.0
        layer.mlp.gate.weight.data[1] = d * 1.0
        layer.mlp.gate.weight.data[2] = d * -1.0
        layer.mlp.gate.weight.data[3] = d * -5.0
        # Mock get_layer_modules to return our fake layer
        import obliteratus.abliterate as abl_module
        orig_get_layers = abl_module.get_layer_modules
        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_layer_modules = lambda h: [layer]
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        try:
            pipeline.refusal_directions = {0: d}
            pipeline._strong_layers = [0]
            pipeline._identify_safety_experts()
        finally:
            # Always restore the monkeypatched module globals
            abl_module.get_layer_modules = orig_get_layers
            abl_module.get_ffn_module = orig_get_ffn
        assert 0 in pipeline._expert_safety_scores
        scores = pipeline._expert_safety_scores[0]
        # Expert 0 should be highest safety affinity
        assert scores[0][0] == 0, f"Expert 0 should be top safety, got {scores[0]}"
        # Expert 3 should be lowest
        assert scores[-1][0] == 3, f"Expert 3 should be lowest, got {scores[-1]}"

    def test_moe_inverted_excision_selective(self):
        """Inverted MoE excision should reflect safety experts and remove from capability."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)
        d = torch.randn(hidden, 1)
        d = d / d.norm()
        # Set up safety scores: experts 0,1 are safety, 2,3 are capability
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }
        orig_router = moe.gate.weight.data.clone()
        count = pipeline._project_moe_experts_inverted(
            moe, d, 0, norm_preserve=False, project_biases=False,
        )
        assert count > 0, "Should project some weights"
        # Router should be reflected (capped at 1.5x to prevent extreme logits
        # that cause CUDA illegal memory access in batched expert forward).
        # With router_reg = max(reflect_reg, -0.5) → scale = 1.5:
        # new_proj ≈ orig_proj - 1.5 * orig_proj = -0.5 * orig_proj
        # Additionally, _stabilize_router_weights clamps outliers, so we
        # verify the sign is flipped and magnitude is substantial.
        router_proj = (moe.gate.weight.data @ d.squeeze()).squeeze()
        orig_router_proj = (orig_router @ d.squeeze()).squeeze()
        cosine = torch.nn.functional.cosine_similarity(
            router_proj.unsqueeze(0), -orig_router_proj.unsqueeze(0),
        )
        assert cosine > 0.5, (
            f"Router projection should be at least partially reflected, cosine={cosine.item():.3f}"
        )
        # Safety expert 0: should be reflected (projection negated)
        e0_proj = (moe.experts[0].down_proj.weight.data @ d).norm()
        # After reflection the projection doesn't go to zero — it negates
        assert e0_proj > 1e-4, "Safety expert should have non-zero projection (reflected, not removed)"
        # Capability expert 3: should have projection removed (near zero)
        e3_proj = (moe.experts[3].down_proj.weight.data @ d).norm().item()
        assert e3_proj < 1e-3, f"Capability expert should have projection removed, got {e3_proj}"
# ---------------------------------------------------------------------------
# Nuclear method
# ---------------------------------------------------------------------------
class TestNuclearMethod:
    """The 'nuclear' preset: inverted baseline plus permanent weight techniques
    (embedding projection, activation steering hooks, expert transplant)."""

    def test_nuclear_preset_config(self):
        """Nuclear method should match inverted baseline + permanent weight techniques."""
        cfg = METHODS["nuclear"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 4  # fewer than inverted to avoid over-ablation
        assert cfg["refinement_passes"] == 2  # same as inverted
        assert cfg["reflection_strength"] == 1.25  # tempered for CoT coherence
        assert cfg["project_embeddings"] is True
        assert cfg["embed_regularization"] == 0.50  # conservative cascade limit
        assert cfg["activation_steering"] is True  # residual cleanup hooks
        assert cfg["steering_strength"] == 0.15  # light residual correction
        assert cfg["expert_transplant"] is True
        assert cfg["transplant_blend"] == 0.10  # gentle nudge, not overwrite
        assert cfg["use_jailbreak_contrast"] is True
        assert cfg["attention_head_surgery"] is True
        assert cfg["layer_adaptive_strength"] is True  # per-layer scaling

    def test_nuclear_pipeline_init(self):
        """Pipeline initialized with nuclear method should have all flags set."""
        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        assert pipeline.invert_refusal is True
        assert pipeline.reflection_strength == 1.25
        assert pipeline.embed_regularization == 0.50
        assert pipeline.transplant_blend == 0.10
        assert pipeline.project_embeddings is True
        assert pipeline.activation_steering is True  # residual cleanup
        assert pipeline.expert_transplant is True
        assert pipeline.n_directions == 4
        assert pipeline.refinement_passes == 2
        assert pipeline.layer_adaptive_strength is True

    def test_reflection_strength_configurable(self):
        """reflection_strength should be explicitly overridable."""
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", reflection_strength=3.0,
        )
        assert pipeline.reflection_strength == 3.0

    def test_inverted_default_strength_is_2(self):
        """Inverted method should default to reflection_strength=2.0."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.reflection_strength == 2.0

    def test_boosted_reflection_math(self):
        """2.5x reflection should produce stronger negation than 2x."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 32, bias=False)

        d = torch.randn(hidden, 1)
        d = d / d.norm()
        # 2x reflection
        module_2x = Wrapper()
        torch.manual_seed(42)
        module_2x.o_proj.weight.data = torch.randn(32, hidden)
        orig = module_2x.o_proj.weight.data.clone()
        AbliterationPipeline._project_out_advanced(
            module_2x, d, ["o_proj"], regularization=-1.0,  # scale=2.0
        )
        proj_2x = (module_2x.o_proj.weight.data @ d).squeeze()
        # 2.5x reflection (same starting weights for a fair comparison)
        module_25x = Wrapper()
        module_25x.o_proj.weight.data = orig.clone()
        AbliterationPipeline._project_out_advanced(
            module_25x, d, ["o_proj"], regularization=-1.5,  # scale=2.5
        )
        proj_25x = (module_25x.o_proj.weight.data @ d).squeeze()
        # 2.5x should be 25% stronger negation than 2x
        assert proj_25x.norm() > proj_2x.norm(), (
            "2.5x reflection should produce stronger (more negative) projection than 2x"
        )

    def test_activation_steering_hook(self):
        """Steering hooks should subtract refusal direction from hidden states."""
        hidden = 8

        class FakeLayer(torch.nn.Module):
            def forward(self, x):
                return x

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])
        # Explicitly enable steering (nuclear preset has it off by default)
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", activation_steering=True,
            steering_strength=0.5,
        )
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        d = torch.randn(hidden)
        d = d / d.norm()
        pipeline.refusal_directions = {0: d}
        pipeline._strong_layers = [0]
        n_hooks = pipeline._install_activation_steering(layers)
        assert n_hooks == 1
        assert len(pipeline._steering_hooks) == 1
        # Create a hidden state with strong refusal component
        batch = torch.randn(1, 4, hidden)
        refusal_component = 5.0 * d.unsqueeze(0).unsqueeze(0).expand_as(batch)
        input_hidden = batch + refusal_component
        # Run through the layer (hook should fire)
        output = layer(input_hidden)
        # The refusal component should be reduced
        proj_before = torch.einsum("bsh,h->bs", input_hidden, d).abs().mean()
        proj_after = torch.einsum("bsh,h->bs", output, d).abs().mean()
        assert proj_after < proj_before, (
            f"Steering should reduce refusal projection: before={proj_before:.3f}, after={proj_after:.3f}"
        )
        # Cleanup
        for hook in pipeline._steering_hooks:
            hook.remove()

    def test_expert_transplant(self):
        """Expert transplant should overwrite safety expert weights with capability average."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])
        torch.manual_seed(42)
        for p in layer.parameters():
            p.data = torch.randn_like(p.data)
        # Save original safety expert weight
        orig_safety0 = layer.mlp.experts[0].down_proj.weight.data.clone()
        # Save capability expert weights for computing expected mean
        # With top-third classification (n_experts // 3 = 1), only expert 0
        # is safety; experts 1, 2, 3 are all capability.
        cap1 = layer.mlp.experts[1].down_proj.weight.data.clone()
        cap2 = layer.mlp.experts[2].down_proj.weight.data.clone()
        cap3 = layer.mlp.experts[3].down_proj.weight.data.clone()
        expected_mean = (cap1 + cap2 + cap3) / 3.0
        import obliteratus.abliterate as abl_module
        from obliteratus.models.loader import ModelHandle
        from transformers import GPT2Config
        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(model=model, tokenizer=MagicMock(), config=config, model_name="test", task="causal_lm")
        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        # Experts 0,1 are safety (high affinity), 2,3 are capability
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }
        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        try:
            count = pipeline._transplant_expert_weights(layers)
        finally:
            # Restore the monkeypatched module global
            abl_module.get_ffn_module = orig_get_ffn
        assert count >= 1, f"Should blend at least 1 weight (top-third safety expert), got {count}"
        # Safety expert 0 should be a 10% blend toward capability mean
        # (nuclear default transplant_blend=0.10)
        # new = 0.90 * original + 0.10 * capability_mean
        blend = pipeline.transplant_blend  # 0.10
        expected_blend = (1.0 - blend) * orig_safety0 + blend * expected_mean
        transplanted = layer.mlp.experts[0].down_proj.weight.data
        assert torch.allclose(transplanted, expected_blend, atol=1e-4), (
            f"Safety expert weight should be {blend:.0%} blended toward capability mean"
        )
        # Capability expert 2 should be unchanged
        assert torch.allclose(layer.mlp.experts[2].down_proj.weight.data, cap2, atol=1e-6), (
            "Capability expert should be unchanged"
        )

    def test_gather_state_dict_raises_on_missing_offload(self):
        """Should raise RuntimeError (not silently corrupt) when offload dir is missing."""
        from obliteratus.models.loader import ModelHandle
        from transformers import GPT2Config
        config = GPT2Config(n_embd=8, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        # Create a fake model whose state_dict returns a meta tensor
        fake_model = MagicMock()
        meta_tensor = torch.empty(4, 8, device="meta")
        fake_model.state_dict.return_value = {"layer.weight": meta_tensor}
        handle = ModelHandle(
            model=fake_model, tokenizer=MagicMock(), config=config,
            model_name="test", task="causal_lm",
        )
        handle._offload_dir = "/nonexistent/path"
        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        with pytest.raises(RuntimeError, match="bricked checkpoint"):
            pipeline._gather_state_dict()
# ---------------------------------------------------------------------------
# Knee detection
# ---------------------------------------------------------------------------
class TestKneeDetection:
    """Knee-point selection of strong layers from sorted (layer, norm) pairs."""

    def test_empty_input(self):
        """No candidates in, no layers out."""
        assert AbliterationPipeline._select_layers_knee([]) == []

    def test_two_layers(self):
        """With only two candidates, both are kept."""
        selected = AbliterationPipeline._select_layers_knee([(0, 5.0), (1, 3.0)])
        assert set(selected) == {0, 1}

    def test_clear_knee(self):
        """Layers with a sharp dropoff should be separated by knee detection."""
        candidates = [
            (14, 10.0), (15, 9.5), (13, 9.0),  # strong cluster
            (16, 2.0), (12, 1.5), (17, 1.0), (11, 0.5), (18, 0.2), (10, 0.1),
        ]
        selected = AbliterationPipeline._select_layers_knee(candidates)
        # The strong cluster survives; the weak tail must not all be swept in
        for strong in (14, 15, 13):
            assert strong in selected
        assert len(selected) <= 5

    def test_minimum_threshold_filters_noise(self):
        """Layers below 10% of max should be filtered out."""
        # 0.5 is only 5% of 10.0, i.e. below the 10% floor
        selected = AbliterationPipeline._select_layers_knee([(0, 10.0), (1, 0.5)])
        assert 0 in selected

    def test_all_equal_norms(self):
        """When all norms are equal, should select all (or most)."""
        selected = AbliterationPipeline._select_layers_knee([(i, 5.0) for i in range(5)])
        assert len(selected) >= 1
# ---------------------------------------------------------------------------
# Activation collection
# ---------------------------------------------------------------------------
class TestActivationCollection:
    """Hidden-state capture across transformer layers."""

    def test_collect_activations(self, handle):
        """Collected activations: one entry per layer, one CPU tensor per prompt."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(model_name="test")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        handle.tokenizer.return_value = {
            "input_ids": torch.randint(0, 1000, (1, 5)),
            "attention_mask": torch.ones(1, 5, dtype=torch.long),
        }
        layers = get_layer_modules(handle)
        prompts = ["Hello world", "Test prompt"]
        activations = pipeline._collect_activations(layers, prompts, "test")

        assert len(activations) == len(layers)
        for idx in range(len(layers)):
            per_prompt = activations[idx]
            assert len(per_prompt) == len(prompts)
            for act in per_prompt:
                # Activations must be moved off-device and keep hidden width
                assert act.device == torch.device("cpu")
                assert act.shape[-1] == handle.hidden_size
# ---------------------------------------------------------------------------
# Distill: single direction (basic method)
# ---------------------------------------------------------------------------
class TestDistillBasic:
    """Difference-in-means distillation (single direction, 'basic' preset)."""

    def test_single_direction(self, handle):
        """Basic method: single refusal direction via difference-in-means."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)
        pipeline._probe()
        pipeline._distill()

        assert len(pipeline.refusal_directions) == len(get_layer_modules(handle))
        for idx, direction in pipeline.refusal_directions.items():
            # Each layer's direction is unit-norm
            assert abs(direction.norm().item() - 1.0) < 1e-4
            # Single direction: subspace should be (1, hidden_dim)
            assert pipeline.refusal_subspaces[idx].shape[0] == 1
# ---------------------------------------------------------------------------
# Distill: multi-direction SVD (advanced/aggressive method)
# ---------------------------------------------------------------------------
class TestDistillSVD:
    """Multi-direction distillation via SVD ('advanced' preset)."""

    def test_multi_direction_svd(self, handle):
        """Advanced method: SVD extracts multiple refusal directions.

        Note: on small models (hidden_size < 2048 or < 2B params), n_directions
        is automatically capped to 2 to prevent over-ablation. The test model
        (hidden_size=64, 4 layers) triggers this safeguard.
        """
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="advanced",
            harmful_prompts=["bad1", "bad2", "bad3", "bad4", "bad5"],
            harmless_prompts=["good1", "good2", "good3", "good4", "good5"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)
        pipeline._probe()
        pipeline._distill()

        assert len(pipeline.refusal_subspaces) == len(get_layer_modules(handle))
        # Small-model cap applies: n_directions limited to 2 for this tiny model
        expected_dirs = min(2, pipeline.n_directions, 5, handle.hidden_size)
        for subspace in pipeline.refusal_subspaces.values():
            assert subspace.shape[0] == expected_dirs
            assert subspace.shape[1] == handle.hidden_size
        for direction in pipeline.refusal_directions.values():
            # Primary direction stays a unit vector
            assert abs(direction.norm().item() - 1.0) < 1e-4
# ---------------------------------------------------------------------------
# Full pipeline: excise with different methods
# ---------------------------------------------------------------------------
class TestExcise:
    """Weight excision under different method presets."""

    def test_excise_basic(self, handle):
        """Basic method should modify weights."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        layers = get_layer_modules(handle)
        # Snapshot every layer parameter before the pipeline runs
        snapshot = {
            (idx, name): param.data.clone()
            for idx, block in enumerate(layers)
            for name, param in block.named_parameters()
        }
        pipeline._probe()
        pipeline._distill()
        pipeline._excise()

        any_changed = any(
            not torch.allclose(snapshot[(idx, name)], param.data, atol=1e-6)
            for idx, block in enumerate(layers)
            for name, param in block.named_parameters()
        )
        assert any_changed, "Excise should modify at least some weights"

    def test_excise_advanced_norm_preserving(self, handle):
        """Advanced method with norm preservation should maintain weight norms."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="advanced",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)
        get_layer_modules(handle)
        pipeline._probe()
        pipeline._distill()
        pipeline._excise()
        # Advanced path uses _project_out_advanced; at least one strong layer expected
        assert len(pipeline._strong_layers) > 0
# ---------------------------------------------------------------------------
# Rebirth (save)
# ---------------------------------------------------------------------------
class TestRebirth:
    """Model save ('rebirth') stage."""

    def test_rebirth_saves_metadata(self, handle, tmp_path):
        """Rebirth should save model and comprehensive metadata JSON."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            output_dir=str(tmp_path / "output"),
            method="advanced",
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        pipeline._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}
        # Stub out the expensive HF serialization calls
        handle.model.save_pretrained = MagicMock()
        handle.tokenizer.save_pretrained = MagicMock()

        result_path = pipeline._rebirth()
        assert result_path == tmp_path / "output"
        metadata_file = result_path / "abliteration_metadata.json"
        assert metadata_file.exists()
        metadata = json.loads(metadata_file.read_text())

        # Top-level provenance fields
        for key, value in {
            "source_model": "test-model",
            "technique": "refusal_direction_ablation",
            "method": "advanced",
            "strong_layers": [0],
        }.items():
            assert metadata[key] == value
        # Method configuration is embedded verbatim
        assert "method_config" in metadata
        assert metadata["method_config"]["n_directions"] == METHODS["advanced"]["n_directions"]
        assert metadata["method_config"]["norm_preserve"] is True
        # Citations and quality metrics travel with the checkpoint
        assert "references" in metadata
        assert len(metadata["references"]) >= 3
        assert "quality_metrics" in metadata
        assert metadata["quality_metrics"]["perplexity"] == 8.5
# ---------------------------------------------------------------------------
# CLI integration
# ---------------------------------------------------------------------------
class TestCLI:
    """Argument parsing for the abliterate subcommand."""

    def test_abliterate_parser_with_method(self):
        """Test that the abliterate subcommand parses method correctly."""
        import argparse

        parser = argparse.ArgumentParser()
        sub = parser.add_subparsers(dest="command").add_parser("abliterate")
        sub.add_argument("model", type=str)
        sub.add_argument("--output-dir", type=str, default=None)
        sub.add_argument("--device", type=str, default="auto")
        sub.add_argument("--dtype", type=str, default="float16")
        sub.add_argument("--method", type=str, default="advanced",
                         choices=["basic", "advanced", "aggressive"])
        sub.add_argument("--n-directions", type=int, default=None)
        sub.add_argument("--regularization", type=float, default=None)
        sub.add_argument("--refinement-passes", type=int, default=None)

        args = parser.parse_args(
            ["abliterate", "gpt2", "--method", "aggressive", "--n-directions", "6"]
        )
        assert (args.command, args.model) == ("abliterate", "gpt2")
        assert args.method == "aggressive"
        assert args.n_directions == 6
        # Unspecified options keep their declared defaults
        assert args.dtype == "float16"

    def test_default_method(self):
        """Default method should be advanced."""
        import argparse

        parser = argparse.ArgumentParser()
        sub = parser.add_subparsers(dest="command").add_parser("abliterate")
        sub.add_argument("model", type=str)
        sub.add_argument("--method", type=str, default="advanced")
        assert parser.parse_args(["abliterate", "gpt2"]).method == "advanced"
# ---------------------------------------------------------------------------
# Expert-Granular Abliteration (EGA)
# ---------------------------------------------------------------------------
class TestFindRouterModule:
    """Test _find_router_module static method."""

    @staticmethod
    def _moe_with_router(attr_name, hidden=16):
        """Build a fake MoE module whose router Linear lives under *attr_name*."""
        moe = torch.nn.Module()
        setattr(moe, attr_name, torch.nn.Linear(hidden, 4, bias=False))
        moe.experts = torch.nn.ModuleList()
        return moe

    def test_finds_gate(self):
        """Should find a router named 'gate'."""
        moe = self._moe_with_router("gate")
        assert AbliterationPipeline._find_router_module(moe) is moe.gate

    def test_finds_router(self):
        """Should find a router named 'router'."""
        moe = self._moe_with_router("router")
        assert AbliterationPipeline._find_router_module(moe) is moe.router

    def test_auto_detects_unknown_router(self):
        """Should auto-detect a router with unusual name via heuristic."""
        moe = self._moe_with_router("moe_gate_proj")
        assert AbliterationPipeline._find_router_module(moe) is moe.moe_gate_proj

    def test_returns_none_no_router(self):
        """Should return None when no router is found."""
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(16, 16)
        assert AbliterationPipeline._find_router_module(mod) is None
class TestRouterProfilingHooks:
    """Test _install_router_profiling_hooks."""

    def _make_moe_pipeline_and_layers(self, hidden=16, n_experts=4):
        """Create a pipeline with a fake MoE model for router profiling tests.

        Returns ``(pipeline, layers, layer, abl_module, orig_get_ffn)``.  The
        caller MUST restore ``abl_module.get_ffn_module = orig_get_ffn`` in a
        ``finally`` block, since this helper monkey-patches it globally.
        """
        from obliteratus.models.loader import ModelHandle
        from transformers import GPT2Config

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

            def forward(self, x):
                return x

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

            def forward(self, x):
                return (x,)

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(model=model, tokenizer=MagicMock(), config=config, model_name="test", task="causal_lm")
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])
        # Monkey-patch get_ffn_module so the pipeline discovers our fake MoE.
        import obliteratus.abliterate as abl_module
        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        return pipeline, layers, layer, abl_module, orig_get_ffn

    def test_hooks_installed(self):
        """Should install hooks on MoE router modules."""
        pipeline, layers, layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers()
        # Pre-bind so the finally block never raises NameError if the install
        # call itself fails (previously that masked the real exception).
        hooks = []
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)
            assert len(hooks) == 1
            assert 0 in pipeline._routing_harmful
            assert 0 in pipeline._routing_harmless
        finally:
            for h in hooks:
                h.remove()
            abl_module.get_ffn_module = orig_get_ffn

    def test_hooks_record_logits(self):
        """Hooks should record router logits during forward passes."""
        pipeline, layers, layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers()
        # Pre-bind so the finally block never raises NameError if the install
        # call itself fails (previously that masked the real exception).
        hooks = []
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)
            # Simulate harmful forward pass
            pipeline._routing_is_harmful = True
            x = torch.randn(1, 5, 16)
            layer.mlp.gate(x)  # triggers hook
            assert len(pipeline._routing_harmful[0]) == 1
            assert pipeline._routing_harmful[0][0].shape[0] == 4  # n_experts
            # Simulate harmless forward pass
            pipeline._routing_is_harmful = False
            layer.mlp.gate(x)
            assert len(pipeline._routing_harmless[0]) == 1
        finally:
            for h in hooks:
                h.remove()
            abl_module.get_ffn_module = orig_get_ffn

    def test_no_handle_returns_empty(self):
        """Should return empty list when handle is None."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = None
        hooks = pipeline._install_router_profiling_hooks(torch.nn.ModuleList())
        assert hooks == []
class TestComputeExpertGranularDirections:
    """Test _compute_expert_granular_directions."""

    def test_computes_per_expert_directions(self):
        """Should compute per-expert refusal directions from routing data."""
        hidden, n_experts = 16, 4
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        torch.manual_seed(42)
        # Router logits biased so expert 0 is favored on harmful prompts and
        # expert 3 on harmless ones (draw order kept interleaved on purpose).
        harmful_logits, harmless_logits = [], []
        for _ in range(10):
            hl = torch.randn(n_experts)
            hl[0] += 2.0
            harmful_logits.append(hl)
            sl = torch.randn(n_experts)
            sl[3] += 2.0
            harmless_logits.append(sl)
        pipeline._routing_harmful = {0: harmful_logits}
        pipeline._routing_harmless = {0: harmless_logits}
        # Per-prompt activations separated along a synthetic refusal direction.
        direction = torch.randn(hidden)
        direction = direction / direction.norm()
        pipeline._harmful_acts = {0: [torch.randn(hidden) + 1.5 * direction for _ in range(10)]}
        pipeline._harmless_acts = {0: [torch.randn(hidden) - 1.5 * direction for _ in range(10)]}
        pipeline._compute_expert_granular_directions()
        # Layer 0 must have expert directions and dynamic safety scores.
        assert 0 in pipeline._expert_directions
        assert len(pipeline._expert_directions[0]) > 0
        assert 0 in pipeline._expert_safety_scores
        scores = pipeline._expert_safety_scores[0]
        assert len(scores) == n_experts
        score_of = dict(scores)
        # Expert 0 was more activated on harmful prompts, so it should score higher.
        assert score_of[0] > score_of[3], (
            f"Expert 0 should have higher safety score: {score_of[0]} vs {score_of[3]}"
        )

    def test_directions_are_unit_vectors(self):
        """Per-expert directions should be unit normalized."""
        hidden, n_experts = 16, 4
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._strong_layers = [0]
        torch.manual_seed(42)
        pipeline._routing_harmful = {0: [torch.randn(n_experts) for _ in range(10)]}
        pipeline._routing_harmless = {0: [torch.randn(n_experts) for _ in range(10)]}
        pipeline._harmful_acts = {0: [torch.randn(hidden) + torch.ones(hidden) for _ in range(10)]}
        pipeline._harmless_acts = {0: [torch.randn(hidden) - torch.ones(hidden) for _ in range(10)]}
        pipeline._compute_expert_granular_directions()
        for ei, d in pipeline._expert_directions.get(0, {}).items():
            assert abs(d.norm().item() - 1.0) < 1e-4, (
                f"Expert {ei} direction norm={d.norm().item()}, expected 1.0"
            )

    def test_skips_when_no_routing_data(self):
        """Should skip gracefully when no routing data is available."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._routing_harmful = {}
        pipeline._routing_harmless = {}
        pipeline._compute_expert_granular_directions()
        assert not pipeline._expert_directions

    def test_skips_expert_with_low_routing_weight(self):
        """Experts with insufficient routing weight should not get directions."""
        hidden = 16
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._strong_layers = [0]
        # Expert 3 gets a huge negative logit, so it is effectively never routed.
        never_routed = torch.tensor([5.0, 5.0, 5.0, -100.0])
        pipeline._routing_harmful = {0: [never_routed.clone() for _ in range(3)]}
        pipeline._routing_harmless = {0: [never_routed.clone() for _ in range(3)]}
        torch.manual_seed(42)
        pipeline._harmful_acts = {0: [torch.randn(hidden) for _ in range(3)]}
        pipeline._harmless_acts = {0: [torch.randn(hidden) for _ in range(3)]}
        pipeline._compute_expert_granular_directions()
        if 0 in pipeline._expert_directions:
            assert 3 not in pipeline._expert_directions[0], (
                "Expert with near-zero routing weight should not get a direction"
            )
class TestProjectMoEExpertsGranular:
    """Test _project_moe_experts_granular (ModuleList path)."""

    def _make_direction(self, hidden_dim=16):
        """Return a random unit column vector of shape (hidden_dim, 1)."""
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    @staticmethod
    def _unit_vector(dim):
        """Return a random unit vector of shape (dim,)."""
        v = torch.randn(dim)
        return v / v.norm()

    def test_per_expert_directions_applied(self):
        """Each expert should use its own direction when available."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)
        shared_dir = self._make_direction(hidden)
        # Create distinct per-expert unit directions.
        expert_dirs = {ei: self._unit_vector(hidden) for ei in range(n_experts)}
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._expert_directions = {0: expert_dirs}
        # Save originals so we can verify every expert was touched.
        orig_weights = {
            ei: moe.experts[ei].down_proj.weight.data.clone()
            for ei in range(n_experts)
        }
        count = pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )
        assert count > 0, "Should project some weights"
        # All experts should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                moe.experts[ei].down_proj.weight.data, orig_weights[ei]
            ), f"Expert {ei} should be modified"

    def test_falls_back_to_shared_direction(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)
        shared_dir = self._make_direction(hidden)
        # Only expert 0 has a per-expert direction.  (Fix: the previous code
        # divided a random vector by the norm of an *unrelated* random draw
        # before renormalizing; build the unit vector directly instead.)
        expert_dirs = {0: self._unit_vector(hidden)}
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._expert_directions = {0: expert_dirs}
        orig_e1 = moe.experts[1].down_proj.weight.data.clone()
        pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )
        # Experts 1,2,3 should be modified (using shared direction)
        assert not torch.allclose(moe.experts[1].down_proj.weight.data, orig_e1), \
            "Expert 1 should use shared direction fallback"

    def test_router_uses_shared_direction(self):
        """Router should always use the shared direction, not per-expert."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        shared_dir = self._make_direction(hidden)
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        # Deliberately non-unit per-expert direction: it must not reach the router.
        pipeline._expert_directions = {0: {0: torch.randn(hidden)}}
        orig_gate = moe.gate.weight.data.clone()
        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)
        # Gate should be projected
        assert not torch.allclose(moe.gate.weight.data, orig_gate), \
            "Router should be projected with shared direction"
        # Gate's projection onto shared direction should be near zero
        proj = (moe.gate.weight.data @ shared_dir).norm().item()
        assert proj < 1e-4, f"Router should have shared dir removed, proj={proj}"

    def test_shared_expert_uses_shared_direction(self):
        """Shared expert should always use the shared direction."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 2, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(2)])

        moe = FakeMoE()
        shared_dir = self._make_direction(hidden)
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._expert_directions = {0: {0: torch.randn(hidden)}}
        orig_shared = moe.shared_expert.down_proj.weight.data.clone()
        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)
        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, orig_shared), \
            "Shared expert should be projected"
class TestProjectFused3DGranular:
    """Test _project_fused_3d_granular for fused 3D expert tensors."""

    def test_per_expert_directions_on_fused(self):
        """Each expert slice should use its own direction."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        container = FusedExperts()
        torch.manual_seed(42)
        shared_dir = torch.randn(hidden, 1)
        shared_dir = shared_dir / shared_dir.norm()
        # Per-expert unit directions
        expert_dirs = {}
        for ei in range(n_experts):
            d = torch.randn(hidden)
            d = d / d.norm()
            expert_dirs[ei] = d
        orig_data = container.down_proj.data.clone()
        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == n_experts, f"Should project {n_experts} experts, got {count}"
        # Each expert should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_fallback_to_shared_on_fused(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        container = FusedExperts()
        torch.manual_seed(42)
        shared_dir = torch.randn(hidden, 1)
        shared_dir = shared_dir / shared_dir.norm()
        # Only expert 0 has a direction.  (Fix: the old construction used a
        # pointless in-place ``div_(1.0)`` before renormalizing on the next
        # line; build the unit vector directly.)
        d = torch.randn(hidden)
        expert_dirs = {0: d / d.norm()}
        orig_data = container.down_proj.data.clone()
        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == n_experts
        # All experts should be modified (experts 1-3 use shared dir)
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_norm_preserve_on_fused(self):
        """Fused 3D with norm_preserve should maintain per-expert norms."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        container = FusedExperts()
        torch.manual_seed(42)
        shared_dir = torch.randn(hidden, 1)
        shared_dir = shared_dir / shared_dir.norm()
        expert_dirs = {}
        for ei in range(n_experts):
            d = torch.randn(hidden)
            expert_dirs[ei] = d / d.norm()
        orig_norms = [container.down_proj.data[i].norm().item() for i in range(n_experts)]
        AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=True, scale=1.0,
        )
        for i in range(n_experts):
            new_norm = container.down_proj.data[i].norm().item()
            assert abs(orig_norms[i] - new_norm) < 1e-3, (
                f"Expert {i} norm not preserved: {orig_norms[i]:.4f} vs {new_norm:.4f}"
            )

    def test_skips_non_3d_params(self):
        """Should skip parameters that are not 3-dimensional."""
        hidden = 16

        class FlatExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(32, hidden))

        container = FlatExperts()
        shared_dir = torch.randn(hidden, 1)
        shared_dir = shared_dir / shared_dir.norm()
        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, {}, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == 0
class TestEGAExciseIntegration:
    """Test that EGA integrates properly in the excise stage path."""

    def test_ega_pipeline_flags(self):
        """Pipeline with surgical method should enable per_expert_directions."""
        p = AbliterationPipeline(model_name="test", method="surgical")
        assert p.per_expert_directions is True

    def test_ega_only_on_primary_direction(self):
        """EGA should only apply for dir_idx==0, not higher SVD directions."""
        # Enforced by the `and dir_idx == 0` check in _excise; verify the code
        # structure exists via source inspection.
        import inspect
        source = inspect.getsource(AbliterationPipeline._excise_inner)
        assert "dir_idx == 0" in source, "EGA should only apply for primary direction"
        assert "_project_moe_experts_granular" in source, "EGA method should be called in excise"

    def test_ega_distill_integration(self):
        """EGA should be called during distill when per_expert_directions is enabled."""
        import inspect
        source = inspect.getsource(AbliterationPipeline._distill)
        assert "_compute_expert_granular_directions" in source
        assert "per_expert_directions" in source

    def test_nuclear_method_enables_ega(self):
        """Nuclear method should also enable per_expert_directions."""
        assert METHODS["nuclear"]["per_expert_directions"] is True
        p = AbliterationPipeline(model_name="test", method="nuclear")
        assert p.per_expert_directions is True

    def test_basic_method_disables_ega(self):
        """Basic method should not enable per_expert_directions."""
        assert METHODS["basic"].get("per_expert_directions", False) is False

    def test_inverted_method_enables_ega(self):
        """Inverted method should enable per_expert_directions."""
        assert METHODS["inverted"]["per_expert_directions"] is True

    def test_ega_with_routing_data_end_to_end(self):
        """End-to-end: EGA computes directions and granular projection modifies weights."""
        hidden, n_experts = 16, 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        # Fake routing profile plus activations separated along a synthetic
        # refusal direction.
        pipeline._routing_harmful = {0: [torch.randn(n_experts) for _ in range(5)]}
        pipeline._routing_harmless = {0: [torch.randn(n_experts) for _ in range(5)]}
        direction = torch.randn(hidden)
        direction = direction / direction.norm()
        pipeline._harmful_acts = {0: [torch.randn(hidden) + 2 * direction for _ in range(5)]}
        pipeline._harmless_acts = {0: [torch.randn(hidden) - 2 * direction for _ in range(5)]}
        # Step 1: compute EGA directions.
        pipeline._compute_expert_granular_directions()
        assert 0 in pipeline._expert_directions
        assert len(pipeline._expert_directions[0]) > 0
        # Step 2: apply granular projection and confirm weights changed.
        shared = torch.randn(hidden, 1)
        shared = shared / shared.norm()
        before = moe.experts[0].down_proj.weight.data.clone()
        count = pipeline._project_moe_experts_granular(moe, shared, layer_idx=0)
        assert count > 0
        assert not torch.allclose(moe.experts[0].down_proj.weight.data, before), \
            "Expert weights should be modified by EGA"