# Mirror of https://github.com/elder-plinius/OBLITERATUS.git
# (synced 2026-04-28 22:26:15 +02:00; 2635 lines, 102 KiB, Python).
"""Tests for the SOTA abliteration pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import torch
|
|
from transformers import GPT2Config, GPT2LMHeadModel
|
|
|
|
from obliteratus.abliterate import (
|
|
HARMFUL_PROMPTS,
|
|
HARMLESS_PROMPTS,
|
|
METHODS,
|
|
STAGES,
|
|
AbliterationPipeline,
|
|
PipelineStage,
|
|
StageResult,
|
|
)
|
|
from obliteratus.models.loader import ModelHandle
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_tiny_handle():
    """Create a minimal ModelHandle with a tiny GPT-2 for testing.

    The model is a randomly initialised 4-layer GPT-2 (64-dim embeddings,
    2 heads, 1000-token vocabulary) switched to eval mode.  The tokenizer is
    a ``MagicMock`` that returns the same 10-token batch on every call and
    decodes to a fixed sentence.

    Returns:
        A ``ModelHandle`` wrapping the model and mock tokenizer, on which
        ``snapshot()`` has already been called.
    """
    config = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=2,
        n_inner=256,
    )
    model = GPT2LMHeadModel(config)
    model.eval()

    tokenizer = MagicMock()
    tokenizer.pad_token = "<pad>"
    tokenizer.eos_token = "<eos>"
    # Calling the mock like a tokenizer yields one fixed (1, 10) batch.
    tokenizer.return_value = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    tokenizer.decode.return_value = "The capital of France is Paris, a beautiful city"

    handle = ModelHandle(
        model=model,
        tokenizer=tokenizer,
        config=config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    handle.snapshot()
    return handle
|
|
|
|
|
|
def _make_varied_tokenizer(handle):
|
|
"""Set up a tokenizer mock that returns different tokens per call."""
|
|
call_count = [0]
|
|
def mock_tokenizer(prompt, **kwargs):
|
|
call_count[0] += 1
|
|
torch.manual_seed(call_count[0])
|
|
return {
|
|
"input_ids": torch.randint(0, 1000, (1, 5)),
|
|
"attention_mask": torch.ones(1, 5, dtype=torch.long),
|
|
}
|
|
handle.tokenizer.side_effect = mock_tokenizer
|
|
|
|
|
|
@pytest.fixture
def handle():
    """Provide a tiny GPT-2 ``ModelHandle`` built by ``_make_tiny_handle``."""
    return _make_tiny_handle()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data & stage definitions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPrompts:
    """Sanity checks on the built-in harmful/harmless prompt lists."""

    def test_harmful_prompts_expanded(self):
        """The harmful list holds at least 99 prompts."""
        assert len(HARMFUL_PROMPTS) >= 99

    def test_harmless_prompts_expanded(self):
        """The harmless list holds at least 99 prompts."""
        assert len(HARMLESS_PROMPTS) >= 99

    def test_prompt_lists_same_length(self):
        """Both lists have the same number of entries."""
        assert len(HARMFUL_PROMPTS) == len(HARMLESS_PROMPTS)

    def test_prompt_count_512(self):
        """512 prompts across 7 severity tiers."""
        assert len(HARMFUL_PROMPTS) == 512
        assert len(HARMLESS_PROMPTS) == 512

    def test_prompt_volume_slicing(self):
        """Slicing at standard volumes gives correct counts."""
        for n in (33, 66, 99, 256, 512):
            assert len(HARMFUL_PROMPTS[:n]) == n
            assert len(HARMLESS_PROMPTS[:n]) == n
|
|
|
|
|
|
class TestStages:
    """Checks on the pipeline stage table and its dataclasses."""

    def test_six_stages(self):
        """STAGES lists exactly six stages."""
        assert len(STAGES) == 6

    def test_stage_keys(self):
        """Stage keys appear in the expected pipeline order."""
        keys = [s.key for s in STAGES]
        assert keys == ["summon", "probe", "distill", "excise", "verify", "rebirth"]

    def test_stage_dataclass(self):
        """PipelineStage stores the key and name it is constructed with."""
        stage = PipelineStage(key="test", name="TEST", description="A test stage")
        assert stage.key == "test"
        assert stage.name == "TEST"

    def test_stage_result_defaults(self):
        """StageResult defaults: empty message, zero duration, empty details."""
        result = StageResult(stage="test", status="running")
        assert result.message == ""
        assert result.duration == 0.0
        assert result.details == {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Method presets
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestMethods:
    """Checks on the METHODS preset table."""

    def test_methods_exist(self):
        """METHODS exposes exactly the expected preset names."""
        assert set(METHODS.keys()) == {
            "basic",
            "advanced",
            "aggressive",
            "informed",
            "surgical",
            "inverted",
            "nuclear",
            "optimized",
            "failspy",
            "gabliteration",
            "heretic",
            "rdo",
            "spectral_cascade",
        }

    def test_basic_single_direction(self):
        """'basic' uses one direction, no norm preservation, no regularization."""
        cfg = METHODS["basic"]
        assert cfg["n_directions"] == 1
        assert cfg["norm_preserve"] is False
        assert cfg["regularization"] == 0.0
        assert cfg["refinement_passes"] == 1

    def test_advanced_multi_direction(self):
        """'advanced' uses several directions with norm preservation."""
        cfg = METHODS["advanced"]
        assert cfg["n_directions"] > 1
        assert cfg["norm_preserve"] is True
        assert cfg["regularization"] > 0
        assert cfg["refinement_passes"] >= 2

    def test_aggressive_full_gabliteration(self):
        """'aggressive' uses at least 8 directions and 3 refinement passes."""
        cfg = METHODS["aggressive"]
        assert cfg["n_directions"] >= 8
        assert cfg["norm_preserve"] is True
        assert cfg["refinement_passes"] >= 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pipeline init
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPipelineInit:
    """Constructor behaviour of AbliterationPipeline (no model is loaded)."""

    def test_default_prompts(self):
        """Without explicit prompts the built-in lists are used."""
        pipeline = AbliterationPipeline(model_name="test-model")
        assert pipeline.harmful_prompts == HARMFUL_PROMPTS
        assert pipeline.harmless_prompts == HARMLESS_PROMPTS

    def test_custom_prompts(self):
        """Explicit prompt lists replace the built-in ones."""
        harmful = ["bad prompt"]
        harmless = ["good prompt"]
        pipeline = AbliterationPipeline(
            model_name="test-model",
            harmful_prompts=harmful,
            harmless_prompts=harmless,
        )
        assert pipeline.harmful_prompts == harmful
        assert pipeline.harmless_prompts == harmless

    def test_defaults(self):
        """Default device, dtype, output directory and flags."""
        pipeline = AbliterationPipeline(model_name="test-model")
        assert pipeline.device == "auto"
        assert pipeline.dtype == "float16"
        assert pipeline.output_dir == Path("abliterated")
        assert pipeline.trust_remote_code is False
        assert pipeline.handle is None

    def test_default_method_is_advanced(self):
        """The default method is 'advanced' and its preset values are applied."""
        pipeline = AbliterationPipeline(model_name="test-model")
        assert pipeline.method == "advanced"
        assert pipeline.n_directions == METHODS["advanced"]["n_directions"]
        assert pipeline.norm_preserve == METHODS["advanced"]["norm_preserve"]
        assert pipeline.regularization == METHODS["advanced"]["regularization"]

    def test_method_basic(self):
        """Selecting 'basic' applies the basic preset."""
        pipeline = AbliterationPipeline(model_name="test-model", method="basic")
        assert pipeline.n_directions == 1
        assert pipeline.norm_preserve is False
        assert pipeline.regularization == 0.0

    def test_method_aggressive(self):
        """Selecting 'aggressive' applies the aggressive preset."""
        pipeline = AbliterationPipeline(model_name="test-model", method="aggressive")
        assert pipeline.n_directions == 8
        assert pipeline.norm_preserve is True
        assert pipeline.refinement_passes == 3

    def test_explicit_overrides_method(self):
        """Explicit keyword arguments win over the method preset."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            n_directions=6,
            norm_preserve=True,
            regularization=0.5,
            refinement_passes=4,
        )
        assert pipeline.n_directions == 6
        assert pipeline.norm_preserve is True
        assert pipeline.regularization == 0.5
        assert pipeline.refinement_passes == 4

    def test_callbacks(self):
        """on_log and on_stage callbacks receive log messages and stage results."""
        stage_results = []
        log_msgs = []
        pipeline = AbliterationPipeline(
            model_name="test-model",
            on_stage=stage_results.append,
            on_log=log_msgs.append,
        )
        pipeline.log("hello")
        assert log_msgs == ["hello"]

        pipeline._emit("test", "running", "msg")
        assert len(stage_results) == 1
        assert stage_results[0].stage == "test"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _project_out_advanced (norm-preserving + regularization)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestProjectOutAdvanced:
    """Tests for _project_out_advanced (norm-preserving + regularization)."""

    def test_norm_preserving(self):
        """Norm-preserving mode should keep the Frobenius norm roughly constant."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        module.o_proj.weight.data = torch.randn(8, 4)
        original_norm = module.o_proj.weight.data.norm().item()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        AbliterationPipeline._project_out_advanced(
            module, direction, ["o_proj"], norm_preserve=True, regularization=0.0
        )

        new_norm = module.o_proj.weight.data.norm().item()
        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices (hidden_dim=4) where a single direction
        # removes a large fraction of energy, so only require the norm to
        # stay within 15% of the original.
        assert new_norm >= original_norm * 0.85, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"

    def test_regularization_partial_removal(self):
        """Regularization should preserve some of the refusal component."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module_full = Wrapper()
        module_reg = Wrapper()
        torch.manual_seed(42)
        W_orig = torch.randn(8, 4)
        module_full.o_proj.weight.data = W_orig.clone()
        module_reg.o_proj.weight.data = W_orig.clone()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        # Full removal
        AbliterationPipeline._project_out_advanced(
            module_full, direction, ["o_proj"], norm_preserve=False, regularization=0.0
        )
        # Regularized (30% preserved)
        AbliterationPipeline._project_out_advanced(
            module_reg, direction, ["o_proj"], norm_preserve=False, regularization=0.3
        )

        W_full = module_full.o_proj.weight.data
        W_reg = module_reg.o_proj.weight.data

        # Full removal should have zero projection on direction
        proj_full = (W_full @ direction).norm().item()
        assert proj_full < 1e-4

        # Regularized should have non-zero projection (30% preserved)
        proj_reg = (W_reg @ direction).norm().item()
        proj_orig = (W_orig @ direction).norm().item()
        expected_ratio = 0.3
        actual_ratio = proj_reg / proj_orig if proj_orig > 0 else 0
        assert abs(actual_ratio - expected_ratio) < 0.05, \
            f"Expected ~{expected_ratio:.0%} preserved, got {actual_ratio:.0%}"

    def test_norm_preserving_transposed(self):
        """Norm-preserving should also work for transposed weights."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c_proj = torch.nn.Linear(8, 4, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        # Weight is (4, 8): the hidden dimension (4) is on the first axis here.
        module.c_proj.weight.data = torch.randn(4, 8)
        original_norm = module.c_proj.weight.data.norm().item()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        AbliterationPipeline._project_out_advanced(
            module, direction, ["c_proj"], norm_preserve=True, regularization=0.0
        )

        new_norm = module.c_proj.weight.data.norm().item()
        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices where a single direction removes a large
        # fraction of energy.
        assert new_norm >= original_norm * 0.80, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Full attention projection (q/k/v + o_proj)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAttentionFullProjection:
    """Test that ALL attention weight matrices are projected (not just o_proj)."""

    def test_qkv_all_projected(self):
        """q_proj, k_proj, v_proj should all be projected alongside o_proj."""
        hidden = 16

        class FakeAttn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.q_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.k_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.v_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.o_proj = torch.nn.Linear(hidden, hidden, bias=False)

        attn = FakeAttn()
        torch.manual_seed(42)
        # Overwrite the default init with seeded values so the test is reproducible.
        for p in attn.parameters():
            p.data = torch.randn_like(p.data)

        originals = {
            name: getattr(attn, name).weight.data.clone()
            for name in ["q_proj", "k_proj", "v_proj", "o_proj"]
        }

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        from obliteratus.abliterate import _ATTN_OUT_NAMES, _ATTN_IN_NAMES
        count = AbliterationPipeline._project_out_advanced(
            attn, d, _ATTN_OUT_NAMES + _ATTN_IN_NAMES,
        )

        assert count == 4, f"Should project 4 weights (q/k/v/o), got {count}"
        for name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
            assert not torch.allclose(
                getattr(attn, name).weight.data, originals[name]
            ), f"{name} should be modified"

    def test_project_all_does_not_early_return(self):
        """_project_out_advanced should project ALL matching weights, not just first."""
        hidden = 16

        class FakeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        # Seed before construction so the module's initial weights are
        # reproducible as well as the direction.
        torch.manual_seed(42)
        mod = FakeModule()
        orig_up = mod.up_proj.weight.data.clone()
        orig_gate = mod.gate_proj.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        from obliteratus.abliterate import _FFN_IN_NAMES
        count = AbliterationPipeline._project_out_advanced(mod, d, _FFN_IN_NAMES)

        assert count == 2, f"Should project both up_proj and gate_proj, got {count}"
        assert not torch.allclose(mod.up_proj.weight.data, orig_up), "up_proj should be modified"
        assert not torch.allclose(mod.gate_proj.weight.data, orig_gate), "gate_proj should be modified"

    def test_lm_head_projection(self):
        """lm_head should be projectable via _project_out_advanced."""
        hidden = 16
        vocab = 100

        class FakeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

        # Seed before construction so the module's initial weights are
        # reproducible as well as the direction.
        torch.manual_seed(42)
        model = FakeModel()
        orig = model.lm_head.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        count = AbliterationPipeline._project_out_advanced(
            model, d, ["lm_head"], regularization=0.0,
        )

        assert count == 1, "Should project lm_head"
        assert not torch.allclose(model.lm_head.weight.data, orig), "lm_head should be modified"
        # Verify refusal direction is removed from lm_head
        proj = (model.lm_head.weight.data @ d).norm().item()
        assert proj < 1e-4, f"Refusal direction should be removed from lm_head, proj={proj}"
|
|
|
|
|
|
class TestKneeDetectionThreshold:
    """Test that knee detection uses 5% threshold to include more layers."""

    def test_five_percent_threshold_includes_more(self):
        """Layers between 5% and 10% of max should now be included."""
        # (layer index, norm) pairs: max=10.0, then two layers between 5%-10%
        sorted_layers = [(0, 10.0), (1, 8.0), (2, 6.0), (3, 0.7), (4, 0.6)]
        selected = AbliterationPipeline._select_layers_knee(sorted_layers)
        # 0.7 and 0.6 are 7% and 6% of max — should now be included (> 5% threshold).
        # The assertion only requires at least one of the two to be selected.
        assert 3 in selected or 4 in selected, (
            f"Layers with 6-7% of max signal should be included, got {selected}"
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MoE projection (router, shared expert, input/output, fused)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestProjectMoEExperts:
    """Test the full MoE projection pipeline: router, shared expert, experts."""

    def _make_direction(self, hidden_dim=16):
        """Return a random unit-norm column vector of shape (hidden_dim, 1)."""
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    def test_router_gate_projected(self):
        """Router/gate weight should have refusal direction removed."""
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        W_gate_orig = moe.gate.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0

        # Gate weight should have been modified
        assert not torch.allclose(moe.gate.weight.data, W_gate_orig), \
            "Router/gate weights should be projected"

        # The gate weight's projection onto the direction should be ~0
        proj = (moe.gate.weight.data @ d).norm().item()
        assert proj < 1e-4, f"Gate should have no component along refusal dir, got {proj}"

    def test_shared_expert_projected(self):
        """Shared expert (always-on) should have both input and output projected."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 2, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(2)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        shared_down_orig = moe.shared_expert.down_proj.weight.data.clone()
        shared_up_orig = moe.shared_expert.up_proj.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0

        # Both shared expert output AND input projections should be modified
        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, shared_down_orig), \
            "Shared expert output (down_proj) should be projected"
        assert not torch.allclose(moe.shared_expert.up_proj.weight.data, shared_up_orig), \
            "Shared expert input (up_proj) should be projected"

    def test_expert_input_projections_projected(self):
        """Expert input projections (up_proj, gate_proj) should also be modified."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(2)])

        moe = FakeMoE()
        d = self._make_direction(hidden)
        up_orig = moe.experts[0].up_proj.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # Each expert contributes 2 projections (output + input)
        # 2 experts * 2 = 4 minimum
        assert count >= 4, f"Expected >= 4 projections (out+in per expert), got {count}"

        assert not torch.allclose(moe.experts[0].up_proj.weight.data, up_orig), \
            "Expert input (up_proj) should be projected"

    def test_fused_3d_output_and_input(self):
        """Fused 3D parameter patterns (GPT-OSS style) should project both directions."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # One (n_experts, intermediate, hidden) parameter per projection.
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))
                self.up_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)
        down_orig = moe.experts.down_proj.data.clone()
        up_orig = moe.experts.up_proj.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # 4 experts output + 4 experts input = 8
        assert count == 8, f"Expected 8 fused projections, got {count}"

        assert not torch.allclose(moe.experts.down_proj.data, down_orig), \
            "Fused output (down_proj) should be projected"
        assert not torch.allclose(moe.experts.up_proj.data, up_orig), \
            "Fused input (up_proj) should be projected"

    def test_fused_3d_norm_preserve(self):
        """Fused 3D projections should preserve norms when requested."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)

        # Record per-expert norms before
        orig_norms = [moe.experts.down_proj.data[i].norm().item() for i in range(n_experts)]

        AbliterationPipeline._project_moe_experts(moe, d, norm_preserve=True)

        # Check per-expert norms preserved
        for i in range(n_experts):
            new_norm = moe.experts.down_proj.data[i].norm().item()
            assert abs(orig_norms[i] - new_norm) < 1e-3, \
                f"Expert {i} norm not preserved: {orig_norms[i]:.4f} vs {new_norm:.4f}"

    def test_no_experts_returns_zero(self):
        """Module without experts attribute should return 0."""
        class NoMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = torch.nn.Linear(16, 32)

        moe = NoMoE()
        d = self._make_direction(16)
        assert AbliterationPipeline._project_moe_experts(moe, d) == 0

    def test_router_bias_projected(self):
        """With project_biases=True a (num_experts,)-shaped router bias stays unchanged."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(4)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        bias_orig = moe.gate.bias.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d, project_biases=True)

        # Gate has 4 outputs (num_experts), direction has 16 dims
        # bias shape (4,) != direction shape (16,), so bias won't match.
        # This is correct: router bias is (num_experts,), not (hidden_dim,),
        # so _project_bias won't modify it (shape mismatch is expected).
        assert torch.allclose(moe.gate.bias.data, bias_orig), (
            "Router bias should be unchanged when shape mismatches direction"
        )
        assert isinstance(count, int)
        assert count > 0  # expert weights should still be projected

    def test_router_auto_detection_fallback(self):
        """Unknown router name should be auto-detected and projected."""
        import warnings as w
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Unusual router name not in _ROUTER_NAMES
                self.moe_gate_proj = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        gate_orig = moe.moe_gate_proj.weight.data.clone()

        # Record every warning raised during projection, including repeats.
        with w.catch_warnings(record=True) as caught:
            w.simplefilter("always")
            AbliterationPipeline._project_moe_experts(moe, d)

        # Should auto-detect and project the unusual router name
        assert not torch.allclose(moe.moe_gate_proj.weight.data, gate_orig), \
            "Auto-detected router should be projected"

        # Should emit a warning about the auto-detection
        auto_detect_warnings = [
            x for x in caught
            if "auto-detected" in str(x.message)
        ]
        assert len(auto_detect_warnings) > 0, "Should warn about auto-detected router"

    def test_full_moe_all_components(self):
        """End-to-end: all MoE components should be modified together."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(4)])

        moe = FakeMoE()
        d = self._make_direction(hidden)

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # Expected: 1 (gate) + 2 (shared out+in) + 4*2 (expert out+in) = 11
        assert count == 11, f"Expected 11 total projections, got {count}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOTA technique #1: Safety-neuron masking (GateBreaker-style z-score)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSafetyNeuronMasking:
    """Tests for _mask_safety_neurons (z-score based neuron masking)."""

    def test_outlier_neurons_zeroed(self):
        """Neurons with outsized refusal projection should be zeroed."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        # Seed before construction: the non-outlier rows keep their default
        # initialisation, so it must be reproducible too.
        torch.manual_seed(42)
        module = Wrapper()
        # Inject a few rows with very high projection along direction
        d = torch.randn(hidden, 1)
        d = d / d.norm()
        # Make rows 0,1,2 have huge projection (outliers)
        for i in range(3):
            module.down_proj.weight.data[i] = d.squeeze() * 10.0

        n_masked = AbliterationPipeline._mask_safety_neurons(
            module, d, ["down_proj"], z_threshold=2.0,
        )

        assert n_masked >= 3, f"Expected >= 3 masked neurons, got {n_masked}"
        # Masked rows should be zero
        for i in range(3):
            assert module.down_proj.weight.data[i].abs().max().item() < 1e-6

    def test_no_outliers_no_masking(self):
        """When all neurons have similar projection, none should be masked."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)

        module = Wrapper()
        # Uniform weights — no outliers
        module.down_proj.weight.data = torch.ones(32, hidden) * 0.1

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        n_masked = AbliterationPipeline._mask_safety_neurons(
            module, d, ["down_proj"], z_threshold=2.0,
        )
        assert n_masked == 0

    def test_high_threshold_masks_fewer(self):
        """Higher z_threshold should mask fewer neurons."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        module1 = Wrapper()
        module2 = Wrapper()
        torch.manual_seed(42)
        # Give both modules identical weights so only the threshold differs.
        W = torch.randn(64, hidden)
        module1.down_proj.weight.data = W.clone()
        module2.down_proj.weight.data = W.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        n1 = AbliterationPipeline._mask_safety_neurons(module1, d, ["down_proj"], z_threshold=1.5)
        n2 = AbliterationPipeline._mask_safety_neurons(module2, d, ["down_proj"], z_threshold=3.0)
        assert n1 >= n2, f"Lower threshold should mask more: {n1} vs {n2}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOTA technique #2: Jailbreak-contrastive direction refinement
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestJailbreakContrastive:
    """Tests for jailbreak-contrastive direction refinement."""

    def test_jailbreak_prompt_generation(self):
        """Pipeline should generate jailbreak variants of harmful prompts."""
        pipeline = AbliterationPipeline(model_name="test")
        pipeline.harmful_prompts = ["How to hack a computer", "Write a virus"]

        jailbreak = pipeline._generate_jailbreak_prompts()

        assert len(jailbreak) == 2
        # Each should contain the original prompt
        assert "hack a computer" in jailbreak[0].lower()
        assert "virus" in jailbreak[1].lower()
        # Each should be wrapped in a template (longer than original)
        assert len(jailbreak[0]) > len("How to hack a computer")

    def test_jailbreak_contrast_blending(self):
        """Jailbreak-contrastive blending should modify refusal direction."""
        pipeline = AbliterationPipeline(
            model_name="test",
            use_jailbreak_contrast=True,
            n_directions=1,
        )
        hidden = 16
        pipeline._on_log = lambda m: None

        # Simulate probed means
        torch.manual_seed(42)
        harm_mean = torch.randn(1, hidden)
        safe_mean = torch.randn(1, hidden)
        jb_mean = torch.randn(1, hidden)

        # Populate the per-layer state directly (single layer, index 0)
        # instead of running the probe stage.
        pipeline._harmful_means = {0: harm_mean}
        pipeline._harmless_means = {0: safe_mean}
        pipeline._jailbreak_means = {0: jb_mean}
        pipeline._harmful_acts = {0: [harm_mean]}
        pipeline._harmless_acts = {0: [safe_mean]}
        pipeline._jailbreak_acts = {0: [jb_mean]}

        # Run distill (will set standard direction, then blend)
        pipeline._distill()

        # Direction should be a unit vector
        d = pipeline.refusal_directions[0]
        assert abs(d.norm().item() - 1.0) < 1e-4

        # Direction should differ from pure harm-safe difference
        std_diff = (harm_mean - safe_mean).squeeze()
        std_dir = std_diff / std_diff.norm()
        cosine = (d @ std_dir).item()
        # Blended direction should not be identical to standard
        assert cosine < 0.99, f"Blended direction too similar to standard: cos={cosine}"

    def test_surgical_method_enables_jailbreak(self):
        """Surgical method should enable jailbreak-contrastive by default."""
        cfg = METHODS["surgical"]
        assert cfg["use_jailbreak_contrast"] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOTA technique #3: Layer-adaptive projection strength
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLayerAdaptiveStrength:
    """Tests for layer-adaptive projection strength."""

    def test_layer_weights_computed(self):
        """Layer-adaptive weights should be proportional to refusal signal."""
        pipeline = AbliterationPipeline(
            model_name="test",
            layer_adaptive_strength=True,
            n_directions=1,
        )
        hidden = 16
        pipeline._on_log = lambda m: None

        # Simulate: layer 0 has strong signal, layer 1 weak
        torch.manual_seed(42)
        strong_diff = torch.randn(1, hidden) * 10.0
        weak_diff = torch.randn(1, hidden) * 1.0
        zero_mean = torch.zeros(1, hidden)

        # Harmless means are zero, so each layer's harmful mean is also its
        # harmful-minus-harmless difference.
        pipeline._harmful_means = {0: strong_diff, 1: weak_diff}
        pipeline._harmless_means = {0: zero_mean, 1: zero_mean}
        pipeline._harmful_acts = {0: [strong_diff], 1: [weak_diff]}
        pipeline._harmless_acts = {0: [zero_mean], 1: [zero_mean]}

        pipeline._distill()

        # Layer weights should exist for strong layers
        assert len(pipeline._layer_excise_weights) > 0
        # Strongest layer should have weight ~1.0
        max_weight = max(pipeline._layer_excise_weights.values())
        assert max_weight > 0.9, f"Max weight should be ~1.0, got {max_weight}"

    def test_surgical_method_enables_adaptive(self):
        """Surgical method should enable layer-adaptive by default."""
        cfg = METHODS["surgical"]
        assert cfg["layer_adaptive_strength"] is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOTA technique #5: Attention head surgery
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAttentionHeadSurgery:
    """Checks for ``AbliterationPipeline._project_head_selective``.

    Each test hands the method a stand-in attention module that exposes only
    a bias-free ``o_proj`` and then inspects per-head column blocks of
    ``o_proj.weight`` (head ``h`` is taken to own columns
    ``h*head_dim:(h+1)*head_dim``, which is how the assertions slice it).
    """

    @staticmethod
    def _make_attn(in_features, out_features):
        """Build a minimal attention stand-in with a bias-free ``o_proj``."""

        class FakeAttn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(in_features, out_features, bias=False)

        return FakeAttn()

    @staticmethod
    def _head_cols(weight, head, head_dim):
        """Return the column block of ``weight`` that belongs to ``head``."""
        return weight[:, head * head_dim:(head + 1) * head_dim]

    def test_head_selective_projection(self):
        """Selective head projection should only modify targeted head columns."""
        hidden = 16
        n_heads = 4
        head_dim = hidden // n_heads

        # Seed BEFORE building the module: nn.Linear draws its initial weights
        # from the global RNG, so seeding afterwards leaves them unseeded.
        torch.manual_seed(42)
        attn = self._make_attn(hidden, hidden)
        W_orig = attn.o_proj.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Head scores: head 0 is top safety head, head 3 is lowest
        head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]

        n_modified = AbliterationPipeline._project_head_selective(
            attn, d, head_scores, n_heads=n_heads, head_fraction=0.25,
        )

        assert n_modified >= 1, "Should modify at least 1 head"

        W_new = attn.o_proj.weight.data
        # Head 0 columns (targeted) should be modified
        assert not torch.allclose(
            self._head_cols(W_new, 0, head_dim),
            self._head_cols(W_orig, 0, head_dim),
        ), "Targeted head 0 should be modified"

        # Head 3 columns (NOT targeted) should be untouched
        assert torch.allclose(
            self._head_cols(W_new, 3, head_dim),
            self._head_cols(W_orig, 3, head_dim),
        ), "Non-targeted head 3 should be untouched"

    def test_head_surgery_norm_preserve(self):
        """Head surgery with norm_preserve should maintain per-head norms."""
        hidden = 16
        n_heads = 4
        head_dim = hidden // n_heads

        # Seed first so the module's random init is reproducible as well.
        torch.manual_seed(42)
        attn = self._make_attn(hidden, hidden)

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        orig_norms = [
            self._head_cols(attn.o_proj.weight.data, h, head_dim).norm().item()
            for h in range(n_heads)
        ]

        head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]
        AbliterationPipeline._project_head_selective(
            attn, d, head_scores, n_heads=n_heads,
            head_fraction=0.5, norm_preserve=True,
        )

        # Targeted heads should have preserved norms
        for h in range(2):  # top 50% = 2 heads
            new_norm = self._head_cols(attn.o_proj.weight.data, h, head_dim).norm().item()
            assert abs(orig_norms[h] - new_norm) < 1e-3, \
                f"Head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"

    def test_head_surgery_non_square_gqa(self):
        """Head surgery should work for GQA models with non-square o_proj (attn_dim != hidden_dim)."""
        hidden_dim = 12   # model hidden dimension
        attn_dim = 32     # attention dimension (n_heads * head_dim_attn)
        n_heads = 4
        head_dim_attn = attn_dim // n_heads  # 8

        # o_proj maps attn_dim -> hidden_dim, so its weight is
        # (hidden_dim, attn_dim) = (12, 32).
        attn = self._make_attn(attn_dim, hidden_dim)
        # The weights are overwritten right after seeding, so every tensor
        # the assertions look at is covered by the seed.
        torch.manual_seed(42)
        attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)
        W_orig = attn.o_proj.weight.data.clone()

        d = torch.randn(hidden_dim, 1)
        d = d / d.norm()

        head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]

        n_modified = AbliterationPipeline._project_head_selective(
            attn, d, head_scores, n_heads=n_heads, head_fraction=0.25,
        )

        assert n_modified >= 1, "Should modify at least 1 head"

        W_new = attn.o_proj.weight.data
        # Head 0 columns (targeted) should be modified
        assert not torch.allclose(
            self._head_cols(W_new, 0, head_dim_attn),
            self._head_cols(W_orig, 0, head_dim_attn),
        ), "Targeted head 0 should be modified"

        # Head 3 columns (NOT targeted) should be untouched
        assert torch.allclose(
            self._head_cols(W_new, 3, head_dim_attn),
            self._head_cols(W_orig, 3, head_dim_attn),
        ), "Non-targeted head 3 should be untouched"

    def test_head_surgery_gqa_norm_preserve(self):
        """Head surgery on GQA non-square o_proj with norm_preserve."""
        hidden_dim = 12
        attn_dim = 32
        n_heads = 4
        head_dim_attn = attn_dim // n_heads

        attn = self._make_attn(attn_dim, hidden_dim)
        torch.manual_seed(42)
        attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)

        d = torch.randn(hidden_dim, 1)
        d = d / d.norm()

        orig_norms = [
            self._head_cols(attn.o_proj.weight.data, h, head_dim_attn).norm().item()
            for h in range(n_heads)
        ]

        head_scores = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]
        AbliterationPipeline._project_head_selective(
            attn, d, head_scores, n_heads=n_heads,
            head_fraction=0.5, norm_preserve=True,
        )

        for h in range(2):  # top 50% = 2 heads
            new_norm = self._head_cols(attn.o_proj.weight.data, h, head_dim_attn).norm().item()
            assert abs(orig_norms[h] - new_norm) < 1e-3, \
                f"GQA head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SOTA technique #6: SAE feature-level abliteration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSAEAbliteration:
    """Checks for the sparse-autoencoder helpers in ``sae_abliteration``."""

    def test_sae_train_and_reconstruct(self):
        """A trained SAE runs a forward pass with the expected shapes and sparsity."""
        from obliteratus.analysis.sae_abliteration import train_sae

        hidden = 32
        expansion = 2

        # Synthetic training activations.
        torch.manual_seed(42)
        training_acts = [torch.randn(hidden) for _ in range(64)]

        sae = train_sae(training_acts, hidden, expansion=2, n_epochs=10, lr=1e-3)

        # One forward pass on a fresh sample.
        sample = torch.randn(1, hidden)
        reconstruction, features = sae(sample)
        assert reconstruction.shape == sample.shape
        assert features.shape == (1, expansion * hidden)

        # A sizeable share of the feature activations must be exactly zero.
        zero_fraction = (features == 0).float().mean()
        assert zero_fraction > 0.3, "Features should be sparse"

    def test_refusal_feature_identification(self):
        """Feature identification returns ``top_k`` directions related to the planted one."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 32
        torch.manual_seed(42)

        # Plant a unit direction that separates the two prompt groups.
        refusal_dir = torch.randn(hidden)
        refusal_dir = refusal_dir / refusal_dir.norm()

        harmful_acts = [torch.randn(hidden) + 2.0 * refusal_dir for _ in range(32)]
        harmless_acts = [torch.randn(hidden) - 2.0 * refusal_dir for _ in range(32)]

        sae = train_sae(
            harmful_acts + harmless_acts, hidden, expansion=2, n_epochs=30, lr=3e-4,
        )
        result = identify_refusal_features(
            sae, harmful_acts, harmless_acts, layer_idx=0, top_k=4,
        )

        assert result.n_refusal_features == 4
        assert result.sae_directions.shape == (4, hidden)
        assert result.variance_explained > 0.0

        # At least one returned direction should overlap the planted one.
        n_found = result.sae_directions.shape[0]
        overlaps = [
            abs((result.sae_directions[row] @ refusal_dir).item())
            for row in range(n_found)
        ]
        best_cos = max(overlaps)
        assert best_cos > 0.1, f"SAE should find direction aligned with refusal: best_cos={best_cos}"

    def test_sae_directions_unit_norm(self):
        """Every direction returned by feature identification has norm 1."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 16
        torch.manual_seed(42)
        harmful = [torch.randn(hidden) + torch.ones(hidden) for _ in range(16)]
        harmless = [torch.randn(hidden) - torch.ones(hidden) for _ in range(16)]

        sae = train_sae(harmful + harmless, hidden, expansion=2, n_epochs=10)
        result = identify_refusal_features(sae, harmful, harmless, 0, top_k=3)

        i = 0
        while i < result.sae_directions.shape[0]:
            norm = result.sae_directions[i].norm().item()
            assert abs(norm - 1.0) < 1e-3, f"Direction {i} norm={norm}, expected 1.0"
            i += 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Surgical method preset
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSurgicalMethod:
    """Preset-level checks for the ``surgical`` and ``basic`` methods."""

    def test_surgical_enables_all_sota(self):
        """Surgical method should enable all 6 SOTA techniques."""
        surgical_preset = METHODS["surgical"]
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
            "per_expert_directions",
            "attention_head_surgery",
            "use_sae_features",
        ):
            assert surgical_preset[flag] is True

    def test_basic_disables_all_sota(self):
        """In ``basic`` the checked flags are either absent or ``False``."""
        basic_preset = METHODS["basic"]
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
        ):
            assert basic_preset.get(flag, False) is False

    def test_pipeline_init_surgical(self):
        """A pipeline built with ``method="surgical"`` exposes every flag as True."""
        surgical_pipeline = AbliterationPipeline(model_name="test", method="surgical")
        for attr in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
            "per_expert_directions",
            "attention_head_surgery",
            "use_sae_features",
        ):
            assert getattr(surgical_pipeline, attr) is True

    def test_pipeline_init_explicit_override(self):
        """A keyword passed to the constructor wins over the preset value."""
        overridden = AbliterationPipeline(
            model_name="test",
            method="surgical",
            safety_neuron_masking=False,
        )
        assert overridden.safety_neuron_masking is False
        # Flags that were not overridden still come from the preset.
        assert overridden.use_jailbreak_contrast is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Inverted method (semantic refusal inversion)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestInvertedMethod:
    """Checks for the ``inverted`` preset and the reflection-style projections."""

    @staticmethod
    def _make_wrapper(in_features, out_features):
        """Build a bare module whose only child is a bias-free ``o_proj``."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(in_features, out_features, bias=False)

        return Wrapper()

    def test_inverted_preset_config(self):
        """Inverted method preset should enable inversion flag."""
        cfg = METHODS["inverted"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 8
        assert cfg["use_jailbreak_contrast"] is True

    def test_surgical_does_not_invert(self):
        """Surgical method should NOT enable inversion by default."""
        cfg = METHODS["surgical"]
        assert cfg.get("invert_refusal", False) is False

    def test_pipeline_init_inverted(self):
        """Pipeline initialized with inverted method should have flag set."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.invert_refusal is True
        assert pipeline.use_jailbreak_contrast is True
        assert pipeline.safety_neuron_masking is False  # zeroing + reflection is destructive

    def test_pipeline_invert_explicit_override(self):
        """Explicit invert_refusal param should override method default."""
        pipeline = AbliterationPipeline(
            model_name="test", method="surgical", invert_refusal=True,
        )
        assert pipeline.invert_refusal is True

        pipeline2 = AbliterationPipeline(
            model_name="test", method="inverted", invert_refusal=False,
        )
        assert pipeline2.invert_refusal is False

    def test_reflection_math(self):
        """2x projection (reflection) should negate the refusal component."""
        hidden = 16

        # Seed BEFORE building the module: nn.Linear draws its initial weights
        # from the global RNG, so seeding afterwards leaves them unseeded.
        torch.manual_seed(42)
        module = self._make_wrapper(hidden, 32)
        W_orig = module.o_proj.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Original projection onto d
        orig_proj = (W_orig @ d).squeeze()

        # Reflection: regularization=-1.0 → scale=2.0
        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )

        W_reflected = module.o_proj.weight.data
        new_proj = (W_reflected @ d).squeeze()

        # After reflection, projection should be NEGATED (sign flipped)
        assert torch.allclose(new_proj, -orig_proj, atol=1e-4), (
            f"Reflected projection should be negated: expected ~{-orig_proj[:3]} got {new_proj[:3]}"
        )

    def test_reflection_preserves_orthogonal_component(self):
        """Reflection should not change the component perpendicular to d."""
        hidden = 8

        # Seed first so the module's random init is reproducible as well.
        torch.manual_seed(42)
        module = self._make_wrapper(hidden, 16)
        W_orig = module.o_proj.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Compute original orthogonal component
        orig_d_component = (W_orig @ d) @ d.T  # rank-1 matrix: projection onto d
        orig_ortho = W_orig - orig_d_component  # everything except d-component

        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )

        W_reflected = module.o_proj.weight.data
        new_d_component = (W_reflected @ d) @ d.T
        new_ortho = W_reflected - new_d_component

        # Orthogonal component should be unchanged
        assert torch.allclose(orig_ortho, new_ortho, atol=1e-4), (
            "Reflection should preserve orthogonal component"
        )

    def test_moe_expert_safety_classification(self):
        """_identify_safety_experts should classify experts by router affinity."""
        import obliteratus.abliterate as abl_module

        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    torch.nn.Linear(hidden, hidden) for _ in range(n_experts)
                ])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        # side_effect hands out a FRESH iterator on every call; a single
        # iterator stored in return_value would be exhausted after one use.
        fake_param = torch.zeros(1)
        model.parameters.side_effect = lambda *args, **kwargs: iter([fake_param])

        handle = ModelHandle(
            model=model, tokenizer=MagicMock(),
            config=config, model_name="test", task="causal_lm",
        )

        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        # Set up fake layer and direction (seeded before any random init).
        torch.manual_seed(42)
        layer = FakeLayer()

        d = torch.randn(hidden)
        d = d / d.norm()
        # Set router weights: expert 0 aligned with d, expert 3 anti-aligned
        for expert_idx, gain in enumerate((5.0, 1.0, -1.0, -5.0)):
            layer.mlp.gate.weight.data[expert_idx] = d * gain

        # Temporarily swap the module-level lookups so the pipeline sees
        # the fake layer; the originals are restored even on failure.
        orig_get_layers = abl_module.get_layer_modules
        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_layer_modules = lambda h: [layer]
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        try:
            pipeline.refusal_directions = {0: d}
            pipeline._strong_layers = [0]
            pipeline._identify_safety_experts()
        finally:
            abl_module.get_layer_modules = orig_get_layers
            abl_module.get_ffn_module = orig_get_ffn

        assert 0 in pipeline._expert_safety_scores
        scores = pipeline._expert_safety_scores[0]
        # Expert 0 should be highest safety affinity
        assert scores[0][0] == 0, f"Expert 0 should be top safety, got {scores[0]}"
        # Expert 3 should be lowest
        assert scores[-1][0] == 3, f"Expert 3 should be lowest, got {scores[-1]}"

    def test_moe_inverted_excision_selective(self):
        """Inverted MoE excision should reflect safety experts and remove from capability."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        # Every parameter is overwritten after seeding, so all weights the
        # assertions look at are reproducible.
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Safety scores rank expert 0 highest and expert 3 lowest; only those
        # two extremes are asserted on below.
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }

        orig_router = moe.gate.weight.data.clone()

        count = pipeline._project_moe_experts_inverted(
            moe, d, 0, norm_preserve=False, project_biases=False,
        )

        assert count > 0, "Should project some weights"

        # Router should be reflected (capped at 1.5x to prevent extreme logits
        # that cause CUDA illegal memory access in batched expert forward).
        # With router_reg = max(reflect_reg, -0.5) → scale = 1.5:
        #   new_proj ≈ orig_proj - 1.5 * orig_proj = -0.5 * orig_proj
        # Additionally, _stabilize_router_weights clamps outliers, so we
        # verify the sign is flipped and magnitude is substantial.
        router_proj = (moe.gate.weight.data @ d.squeeze()).squeeze()
        orig_router_proj = (orig_router @ d.squeeze()).squeeze()
        cosine = torch.nn.functional.cosine_similarity(
            router_proj.unsqueeze(0), -orig_router_proj.unsqueeze(0),
        )
        assert cosine > 0.5, (
            f"Router projection should be at least partially reflected, cosine={cosine.item():.3f}"
        )

        # Safety expert 0: should be reflected (projection negated)
        e0_proj = (moe.experts[0].down_proj.weight.data @ d).norm()
        # After reflection the projection doesn't go to zero — it negates
        assert e0_proj > 1e-4, "Safety expert should have non-zero projection (reflected, not removed)"

        # Capability expert 3: should have projection removed (near zero)
        e3_proj = (moe.experts[3].down_proj.weight.data @ d).norm().item()
        assert e3_proj < 1e-3, f"Capability expert should have projection removed, got {e3_proj}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Nuclear method
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestNuclearMethod:
    """Checks for the ``nuclear`` preset and the techniques it switches on."""

    @staticmethod
    def _make_wrapper(in_features, out_features):
        """Build a bare module whose only child is a bias-free ``o_proj``."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(in_features, out_features, bias=False)

        return Wrapper()

    def test_nuclear_preset_config(self):
        """Nuclear method should match inverted baseline + permanent weight techniques."""
        cfg = METHODS["nuclear"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 4  # fewer than inverted to avoid over-ablation
        assert cfg["refinement_passes"] == 2  # same as inverted
        assert cfg["reflection_strength"] == 1.25  # tempered for CoT coherence
        assert cfg["project_embeddings"] is True
        assert cfg["embed_regularization"] == 0.50  # conservative cascade limit
        assert cfg["activation_steering"] is True  # residual cleanup hooks
        assert cfg["steering_strength"] == 0.15  # light residual correction
        assert cfg["expert_transplant"] is True
        assert cfg["transplant_blend"] == 0.10  # gentle nudge, not overwrite
        assert cfg["use_jailbreak_contrast"] is True
        assert cfg["attention_head_surgery"] is True
        assert cfg["layer_adaptive_strength"] is True  # per-layer scaling

    def test_nuclear_pipeline_init(self):
        """Pipeline initialized with nuclear method should have all flags set."""
        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        assert pipeline.invert_refusal is True
        assert pipeline.reflection_strength == 1.25
        assert pipeline.embed_regularization == 0.50
        assert pipeline.transplant_blend == 0.10
        assert pipeline.project_embeddings is True
        assert pipeline.activation_steering is True  # residual cleanup
        assert pipeline.expert_transplant is True
        assert pipeline.n_directions == 4
        assert pipeline.refinement_passes == 2
        assert pipeline.layer_adaptive_strength is True

    def test_reflection_strength_configurable(self):
        """reflection_strength should be explicitly overridable."""
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", reflection_strength=3.0,
        )
        assert pipeline.reflection_strength == 3.0

    def test_inverted_default_strength_is_2(self):
        """Inverted method should default to reflection_strength=2.0."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.reflection_strength == 2.0

    def test_boosted_reflection_math(self):
        """2.5x reflection should produce stronger negation than 2x."""
        hidden = 16

        # Seed before ANY draw so the direction is reproducible too, not
        # just the weight matrix.
        torch.manual_seed(42)
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # 2x reflection
        module_2x = self._make_wrapper(hidden, 32)
        module_2x.o_proj.weight.data = torch.randn(32, hidden)
        orig = module_2x.o_proj.weight.data.clone()
        AbliterationPipeline._project_out_advanced(
            module_2x, d, ["o_proj"], regularization=-1.0,  # scale=2.0
        )
        proj_2x = (module_2x.o_proj.weight.data @ d).squeeze()

        # 2.5x reflection, starting from the very same weights
        module_25x = self._make_wrapper(hidden, 32)
        module_25x.o_proj.weight.data = orig.clone()
        AbliterationPipeline._project_out_advanced(
            module_25x, d, ["o_proj"], regularization=-1.5,  # scale=2.5
        )
        proj_25x = (module_25x.o_proj.weight.data @ d).squeeze()

        # The larger scale must leave a larger-magnitude projection on d.
        # NOTE(review): going by the scale annotations above the residual
        # should be about -1.5x vs -1.0x the original projection — confirm
        # against _project_out_advanced before asserting the exact ratio.
        assert proj_25x.norm() > proj_2x.norm(), (
            "2.5x reflection should produce stronger (more negative) projection than 2x"
        )

    def test_activation_steering_hook(self):
        """Steering hooks should subtract refusal direction from hidden states."""
        hidden = 8

        class FakeLayer(torch.nn.Module):
            def forward(self, x):
                return x

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])

        # Steering is requested explicitly here instead of relying on
        # whatever the "inverted" preset defaults to.
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", activation_steering=True,
            steering_strength=0.5,
        )
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        # Make the random direction and hidden states reproducible.
        torch.manual_seed(42)
        d = torch.randn(hidden)
        d = d / d.norm()
        pipeline.refusal_directions = {0: d}
        pipeline._strong_layers = [0]

        n_hooks = pipeline._install_activation_steering(layers)
        try:
            assert n_hooks == 1
            assert len(pipeline._steering_hooks) == 1

            # Create a hidden state with strong refusal component
            batch = torch.randn(1, 4, hidden)
            refusal_component = 5.0 * d.unsqueeze(0).unsqueeze(0).expand_as(batch)
            input_hidden = batch + refusal_component

            # Run through the layer (hook should fire)
            output = layer(input_hidden)

            # The refusal component should be reduced
            proj_before = torch.einsum("bsh,h->bs", input_hidden, d).abs().mean()
            proj_after = torch.einsum("bsh,h->bs", output, d).abs().mean()
            assert proj_after < proj_before, (
                f"Steering should reduce refusal projection: before={proj_before:.3f}, after={proj_after:.3f}"
            )
        finally:
            # Detach the hooks even when an assertion above fails.
            for hook in pipeline._steering_hooks:
                hook.remove()

    def test_expert_transplant(self):
        """Expert transplant should blend the safety expert toward the capability average."""
        import obliteratus.abliterate as abl_module

        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])
        # Every parameter is overwritten after seeding, so all weights are
        # reproducible.
        torch.manual_seed(42)
        for p in layer.parameters():
            p.data = torch.randn_like(p.data)

        # Save original safety expert weight
        orig_safety0 = layer.mlp.experts[0].down_proj.weight.data.clone()
        # Save capability expert weights for computing expected mean
        # With top-third classification (n_experts // 3 = 1), only expert 0
        # is safety; experts 1, 2, 3 are all capability.
        cap1 = layer.mlp.experts[1].down_proj.weight.data.clone()
        cap2 = layer.mlp.experts[2].down_proj.weight.data.clone()
        cap3 = layer.mlp.experts[3].down_proj.weight.data.clone()
        expected_mean = (cap1 + cap2 + cap3) / 3.0

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        # side_effect hands out a FRESH iterator on every call; a single
        # iterator stored in return_value would be exhausted after one use.
        fake_param = torch.zeros(1)
        model.parameters.side_effect = lambda *args, **kwargs: iter([fake_param])
        handle = ModelHandle(model=model, tokenizer=MagicMock(), config=config, model_name="test", task="causal_lm")

        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        # Scores rank expert 0 highest; the expectation computed above treats
        # only that top expert as safety and experts 1-3 as capability.
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }

        # Temporarily swap the module-level FFN lookup; restored on failure.
        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        try:
            count = pipeline._transplant_expert_weights(layers)
        finally:
            abl_module.get_ffn_module = orig_get_ffn

        assert count >= 1, f"Should blend at least 1 weight (top-third safety expert), got {count}"

        # Safety expert 0 should be a 10% blend toward capability mean
        # (nuclear default transplant_blend=0.10)
        # new = 0.90 * original + 0.10 * capability_mean
        blend = pipeline.transplant_blend  # 0.10
        expected_blend = (1.0 - blend) * orig_safety0 + blend * expected_mean
        transplanted = layer.mlp.experts[0].down_proj.weight.data
        assert torch.allclose(transplanted, expected_blend, atol=1e-4), (
            f"Safety expert weight should be {blend:.0%} blended toward capability mean"
        )

        # Capability expert 2 should be unchanged
        assert torch.allclose(layer.mlp.experts[2].down_proj.weight.data, cap2, atol=1e-6), (
            "Capability expert should be unchanged"
        )

    def test_gather_state_dict_raises_on_missing_offload(self):
        """Should raise RuntimeError (not silently corrupt) when offload dir is missing."""
        config = GPT2Config(n_embd=8, n_head=2, n_layer=1, vocab_size=100, n_positions=64)

        # Create a fake model whose state_dict returns a meta tensor
        fake_model = MagicMock()
        meta_tensor = torch.empty(4, 8, device="meta")
        fake_model.state_dict.return_value = {"layer.weight": meta_tensor}

        handle = ModelHandle(
            model=fake_model, tokenizer=MagicMock(), config=config,
            model_name="test", task="causal_lm",
        )
        handle._offload_dir = "/nonexistent/path"

        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        with pytest.raises(RuntimeError, match="bricked checkpoint"):
            pipeline._gather_state_dict()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Knee detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestKneeDetection:
    """Behavioural checks for ``AbliterationPipeline._select_layers_knee``."""

    def test_empty_input(self):
        """An empty ranking yields an empty selection."""
        assert AbliterationPipeline._select_layers_knee([]) == []

    def test_two_layers(self):
        """A two-entry ranking keeps both layers."""
        ranked = [(0, 5.0), (1, 3.0)]
        chosen = AbliterationPipeline._select_layers_knee(ranked)
        assert set(chosen) == {0, 1}

    def test_clear_knee(self):
        """A sharp drop in norm separates the strong cluster from the tail."""
        strong_cluster = [(14, 10.0), (15, 9.5), (13, 9.0)]
        weak_tail = [(16, 2.0), (12, 1.5), (17, 1.0), (11, 0.5), (18, 0.2), (10, 0.1)]
        chosen = AbliterationPipeline._select_layers_knee(strong_cluster + weak_tail)

        # The whole strong cluster has to be selected ...
        for layer_idx, _norm in strong_cluster:
            assert layer_idx in chosen
        # ... without pulling in all nine candidates.
        assert len(chosen) <= 5

    def test_minimum_threshold_filters_noise(self):
        """Layers below 10% of max should be filtered out."""
        ranked = [(0, 10.0), (1, 0.5)]  # 0.5 is 5% of 10
        chosen = AbliterationPipeline._select_layers_knee(ranked)
        # NOTE(review): only layer 0's presence is asserted; the exclusion of
        # layer 1 named in the docstring is not actually checked here.
        assert 0 in chosen

    def test_all_equal_norms(self):
        """Identical norms still produce a non-empty selection."""
        flat_ranking = [(layer_idx, 5.0) for layer_idx in range(5)]
        chosen = AbliterationPipeline._select_layers_knee(flat_ranking)
        assert len(chosen) >= 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Activation collection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestActivationCollection:
    """Structure checks for ``AbliterationPipeline._collect_activations``."""

    def test_collect_activations(self, handle):
        """One CPU activation per prompt is returned for every layer."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(model_name="test")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None

        layers = get_layer_modules(handle)
        prompts = ["Hello world", "Test prompt"]

        # Every tokenizer call returns the same 5-token encoding.
        fake_encoding = {
            "input_ids": torch.randint(0, 1000, (1, 5)),
            "attention_mask": torch.ones(1, 5, dtype=torch.long),
        }
        handle.tokenizer.return_value = fake_encoding

        activations = pipeline._collect_activations(layers, prompts, "test")

        n_layers = len(layers)
        assert len(activations) == n_layers

        cpu = torch.device("cpu")
        for layer_idx in range(n_layers):
            per_layer = activations[layer_idx]
            assert len(per_layer) == len(prompts)
            for act in per_layer:
                assert act.device == cpu
                assert act.shape[-1] == handle.hidden_size
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Distill: single direction (basic method)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDistillBasic:
    """Distillation checks for the single-direction ``basic`` method."""

    def test_single_direction(self, handle):
        """``basic`` yields one unit-norm refusal direction per layer."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        pipeline._probe()
        pipeline._distill()

        directions = pipeline.refusal_directions
        assert len(directions) == len(get_layer_modules(handle))

        for layer_idx in directions:
            direction_norm = directions[layer_idx].norm().item()
            assert abs(direction_norm - 1.0) < 1e-4
            # With a single direction the subspace has exactly one row.
            subspace = pipeline.refusal_subspaces[layer_idx]
            assert subspace.shape[0] == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Distill: multi-direction SVD (advanced/aggressive method)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDistillSVD:
    """Distill-stage checks for the multi-direction (SVD) path."""

    def test_multi_direction_svd(self, handle):
        """Advanced method: SVD extracts multiple refusal directions.

        Note: on small models (hidden_size < 2048 or < 2B params), n_directions
        is automatically capped to 2 to prevent over-ablation. The test model
        (hidden_size=64, 4 layers) triggers this safeguard.
        """
        from obliteratus.strategies.utils import get_layer_modules

        n_prompts = 5
        numbering = range(1, n_prompts + 1)
        bad = [f"bad{i}" for i in numbering]
        good = [f"good{i}" for i in numbering]

        pipe = AbliterationPipeline(
            model_name="test",
            method="advanced",
            harmful_prompts=bad,
            harmless_prompts=good,
        )
        pipe.handle = handle
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        pipe._probe()
        pipe._distill()

        layer_count = len(get_layer_modules(handle))
        assert len(pipe.refusal_subspaces) == layer_count

        # Small-model cap: n_directions capped to 2 for tiny test model
        expected_rank = min(2, pipe.n_directions, 5, handle.hidden_size)
        for subspace in pipe.refusal_subspaces.values():
            rows, cols = subspace.shape[0], subspace.shape[1]
            assert rows == expected_rank
            assert cols == handle.hidden_size

        # Primary direction should still be a unit vector
        for vec in pipe.refusal_directions.values():
            assert abs(vec.norm().item() - 1.0) < 1e-4
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Full pipeline: excise with different methods
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestExcise:
    """Excise-stage checks run after probe and distill on the test handle."""

    def test_excise_basic(self, handle):
        """Basic method should modify weights.

        Snapshots every per-layer parameter, runs probe -> distill -> excise,
        and requires at least one parameter to differ beyond ``atol=1e-6``.
        """
        from obliteratus.strategies.utils import get_layer_modules

        pipe = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipe.handle = handle
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        layer_modules = get_layer_modules(handle)
        n_layers = len(layer_modules)

        # Snapshot keyed by (layer index, parameter name).
        before = {
            (layer_idx, pname): tensor.data.clone()
            for layer_idx in range(n_layers)
            for pname, tensor in layer_modules[layer_idx].named_parameters()
        }

        for stage in (pipe._probe, pipe._distill, pipe._excise):
            stage()

        any_changed = any(
            not torch.allclose(before[(layer_idx, pname)], tensor.data, atol=1e-6)
            for layer_idx in range(n_layers)
            for pname, tensor in layer_modules[layer_idx].named_parameters()
        )

        assert any_changed, "Excise should modify at least some weights"

    def test_excise_advanced_norm_preserving(self, handle):
        """Advanced method runs through excise and selects strong layers.

        NOTE(review): despite the name, this only asserts that
        ``_strong_layers`` is non-empty; weight norms are not compared.
        """
        from obliteratus.strategies.utils import get_layer_modules

        settings = {
            "model_name": "test",
            "method": "advanced",
            "harmful_prompts": ["bad prompt"],
            "harmless_prompts": ["good prompt"],
        }
        pipe = AbliterationPipeline(**settings)
        pipe.handle = handle
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        # Return value is unused; the call is kept as in the original test.
        get_layer_modules(handle)

        for stage in (pipe._probe, pipe._distill, pipe._excise):
            stage()

        # Weights should have been modified (advanced uses _project_out_advanced)
        assert len(pipe._strong_layers) > 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Rebirth (save)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRebirth:
    """Checks for the save ("rebirth") stage."""

    def test_rebirth_saves_metadata(self, handle, tmp_path):
        """Rebirth returns the output dir and writes the metadata JSON.

        ``save_pretrained`` is stubbed on both model and tokenizer, so only
        the metadata file produced by ``_rebirth`` is inspected.
        """
        out_dir = tmp_path / "output"

        pipe = AbliterationPipeline(
            model_name="test-model",
            output_dir=str(out_dir),
            method="advanced",
        )
        pipe.handle = handle
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        pipe._strong_layers = [0]
        pipe._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}

        # Stub out the actual weight/tokenizer serialization.
        handle.model.save_pretrained = MagicMock()
        handle.tokenizer.save_pretrained = MagicMock()

        saved_to = pipe._rebirth()

        assert saved_to == out_dir
        meta_file = saved_to / "abliteration_metadata.json"
        assert meta_file.exists()

        meta = json.loads(meta_file.read_text())

        # Top-level identification fields.
        assert meta["source_model"] == "test-model"
        assert meta["technique"] == "refusal_direction_ablation"
        assert meta["method"] == "advanced"
        assert meta["strong_layers"] == [0]

        # Method configuration mirrors the METHODS table.
        assert "method_config" in meta
        method_cfg = meta["method_config"]
        assert method_cfg["n_directions"] == METHODS["advanced"]["n_directions"]
        assert method_cfg["norm_preserve"] is True

        assert "references" in meta
        assert len(meta["references"]) >= 3

        assert "quality_metrics" in meta
        assert meta["quality_metrics"]["perplexity"] == 8.5
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCLI:
    """Argument-parsing checks for the ``abliterate`` subcommand.

    Both tests parse with the same locally built parser so that the
    default-value test and the explicit-option test cannot drift apart.

    NOTE(review): the parser is reconstructed here rather than imported from
    the real CLI module — confirm it is kept in sync with the actual CLI.
    """

    @staticmethod
    def _build_parser():
        """Return an argparse parser exposing the ``abliterate`` subcommand."""
        import argparse

        parser = argparse.ArgumentParser()
        subparsers = parser.add_subparsers(dest="command")
        abl_parser = subparsers.add_parser("abliterate")
        abl_parser.add_argument("model", type=str)
        abl_parser.add_argument("--output-dir", type=str, default=None)
        abl_parser.add_argument("--device", type=str, default="auto")
        abl_parser.add_argument("--dtype", type=str, default="float16")
        abl_parser.add_argument(
            "--method",
            type=str,
            default="advanced",
            choices=["basic", "advanced", "aggressive"],
        )
        abl_parser.add_argument("--n-directions", type=int, default=None)
        abl_parser.add_argument("--regularization", type=float, default=None)
        abl_parser.add_argument("--refinement-passes", type=int, default=None)
        return parser

    def test_abliterate_parser_with_method(self):
        """Test that the abliterate subcommand parses method correctly."""
        args = self._build_parser().parse_args(
            ["abliterate", "gpt2", "--method", "aggressive", "--n-directions", "6"]
        )

        assert args.command == "abliterate"
        assert args.model == "gpt2"
        assert args.method == "aggressive"
        assert args.n_directions == 6
        # --dtype was not given, so its default applies.
        assert args.dtype == "float16"

    def test_default_method(self):
        """Default method should be advanced."""
        args = self._build_parser().parse_args(["abliterate", "gpt2"])

        assert args.method == "advanced"
        # Optional overrides stay unset unless passed explicitly.
        assert args.n_directions is None
        assert args.regularization is None
        assert args.refinement_passes is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Expert-Granular Abliteration (EGA)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestFindRouterModule:
    """Test _find_router_module static method."""

    @staticmethod
    def _moe_with_router(attr_name, hidden=16, n_experts=4):
        """Build a fake MoE block whose bias-free router Linear is ``attr_name``.

        The module also carries an empty ``experts`` ModuleList, registered
        after the router.
        """

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                router = torch.nn.Linear(hidden, n_experts, bias=False)
                setattr(self, attr_name, router)
                self.experts = torch.nn.ModuleList()

        return FakeMoE()

    def test_finds_gate(self):
        """Should find a router named 'gate'."""
        moe = self._moe_with_router("gate")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.gate

    def test_finds_router(self):
        """Should find a router named 'router'."""
        moe = self._moe_with_router("router")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.router

    def test_auto_detects_unknown_router(self):
        """Should auto-detect a router with unusual name via heuristic."""
        moe = self._moe_with_router("moe_gate_proj")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.moe_gate_proj

    def test_returns_none_no_router(self):
        """Should return None when no router is found."""

        class NoRouter(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)

        plain = NoRouter()
        found = AbliterationPipeline._find_router_module(plain)
        assert found is None
|
|
|
|
|
|
class TestRouterProfilingHooks:
    """Test _install_router_profiling_hooks."""

    def _make_moe_pipeline_and_layers(self, hidden=16, n_experts=4):
        """Create a pipeline with a fake MoE model for router profiling tests.

        Side effect: ``obliteratus.abliterate.get_ffn_module`` is replaced by
        a lambda returning ``layer.mlp``. The caller must restore it from the
        returned ``orig_get_ffn`` (see ``_cleanup``).

        Returns:
            ``(pipeline, layers, layer, abl_module, orig_get_ffn)`` where
            ``layers`` is a one-element ModuleList containing ``layer``.
        """

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

            def forward(self, x):
                return x

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

            def forward(self, x):
                return (x,)

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(
            model=model,
            tokenizer=MagicMock(),
            config=config,
            model_name="test",
            task="causal_lm",
        )

        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])

        # Monkey-patch get_ffn_module
        import obliteratus.abliterate as abl_module

        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_ffn_module = lambda lay, a: lay.mlp

        return pipeline, layers, layer, abl_module, orig_get_ffn

    @staticmethod
    def _cleanup(hooks, abl_module, orig_get_ffn):
        """Remove installed hooks and always restore ``get_ffn_module``."""
        try:
            for h in hooks:
                h.remove()
        finally:
            # Restore even if a hook removal fails, so the patch cannot leak
            # into other tests.
            abl_module.get_ffn_module = orig_get_ffn

    def test_hooks_installed(self):
        """Should install hooks on MoE router modules."""
        pipeline, layers, _layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers()
        # Bound before ``try`` so cleanup works even if installation raises.
        hooks = []
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)
            assert len(hooks) == 1
            assert 0 in pipeline._routing_harmful
            assert 0 in pipeline._routing_harmless
        finally:
            self._cleanup(hooks, abl_module, orig_get_ffn)

    def test_hooks_record_logits(self):
        """Hooks should record router logits during forward passes."""
        hidden, n_experts = 16, 4
        pipeline, layers, layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers(
            hidden=hidden, n_experts=n_experts
        )
        # Bound before ``try`` so cleanup works even if installation raises.
        hooks = []
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)

            # Simulate harmful forward pass
            pipeline._routing_is_harmful = True
            x = torch.randn(1, 5, hidden)
            layer.mlp.gate(x)  # triggers hook

            assert len(pipeline._routing_harmful[0]) == 1
            assert pipeline._routing_harmful[0][0].shape[0] == n_experts

            # Simulate harmless forward pass
            pipeline._routing_is_harmful = False
            layer.mlp.gate(x)

            assert len(pipeline._routing_harmless[0]) == 1
        finally:
            self._cleanup(hooks, abl_module, orig_get_ffn)

    def test_no_handle_returns_empty(self):
        """Should return empty list when handle is None."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = None
        hooks = pipeline._install_router_profiling_hooks(torch.nn.ModuleList())
        assert hooks == []
|
|
|
|
|
|
class TestComputeExpertGranularDirections:
    """Test _compute_expert_granular_directions."""

    @staticmethod
    def _score_for(scores, expert_id):
        """Return the first score paired with ``expert_id`` in ``scores``."""
        return next(score for eid, score in scores if eid == expert_id)

    def test_computes_per_expert_directions(self):
        """Should compute per-expert refusal directions from routing data.

        Router logits are biased so expert 0 is favored on harmful prompts
        and expert 3 on harmless ones; expert 0 must end up with the higher
        safety score.
        """
        hidden, n_experts, n_samples = 16, 4, 10

        pipe = AbliterationPipeline(model_name="test", method="surgical")
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        pipe._strong_layers = [0]

        torch.manual_seed(42)

        # Simulate router logits: expert 0 favored for harmful, expert 3 for harmless
        harmful_logits, harmless_logits = [], []
        for _ in range(n_samples):
            row = torch.randn(n_experts)
            row[0] += 2.0  # bias expert 0 for harmful
            harmful_logits.append(row)

            row = torch.randn(n_experts)
            row[3] += 2.0  # bias expert 3 for harmless
            harmless_logits.append(row)

        pipe._routing_harmful = {0: harmful_logits}
        pipe._routing_harmless = {0: harmless_logits}

        # Simulate per-prompt activations with harmful/harmless separation
        axis = torch.randn(hidden)
        axis = axis / axis.norm()
        shift = 1.5 * axis

        pipe._harmful_acts = {0: [torch.randn(hidden) + shift for _ in range(n_samples)]}
        pipe._harmless_acts = {0: [torch.randn(hidden) - shift for _ in range(n_samples)]}

        pipe._compute_expert_granular_directions()

        # Should have computed expert directions for layer 0
        assert 0 in pipe._expert_directions
        assert len(pipe._expert_directions[0]) > 0

        # Should have dynamic safety scores
        assert 0 in pipe._expert_safety_scores
        layer_scores = pipe._expert_safety_scores[0]
        assert len(layer_scores) == n_experts

        # Expert 0 should have higher safety score (more activated for harmful)
        score_e0 = self._score_for(layer_scores, 0)
        score_e3 = self._score_for(layer_scores, 3)
        assert score_e0 > score_e3, (
            f"Expert 0 should have higher safety score: {score_e0} vs {score_e3}"
        )

    def test_directions_are_unit_vectors(self):
        """Per-expert directions should be unit normalized.

        The norm check runs only if layer 0 received directions at all.
        """
        hidden, n_experts, n_samples = 16, 4, 10

        pipe = AbliterationPipeline(model_name="test", method="surgical")
        pipe._on_log = lambda m: None
        pipe._strong_layers = [0]

        torch.manual_seed(42)
        pipe._routing_harmful = {0: [torch.randn(n_experts) for _ in range(n_samples)]}
        pipe._routing_harmless = {0: [torch.randn(n_experts) for _ in range(n_samples)]}
        pipe._harmful_acts = {
            0: [torch.randn(hidden) + torch.ones(hidden) for _ in range(n_samples)]
        }
        pipe._harmless_acts = {
            0: [torch.randn(hidden) - torch.ones(hidden) for _ in range(n_samples)]
        }

        pipe._compute_expert_granular_directions()

        if 0 not in pipe._expert_directions:
            return
        for expert_id, vec in pipe._expert_directions[0].items():
            assert abs(vec.norm().item() - 1.0) < 1e-4, (
                f"Expert {expert_id} direction norm={vec.norm().item()}, expected 1.0"
            )

    def test_skips_when_no_routing_data(self):
        """Should skip gracefully when no routing data is available."""
        pipe = AbliterationPipeline(model_name="test", method="surgical")
        pipe._on_log = lambda m: None
        pipe._routing_harmful, pipe._routing_harmless = {}, {}

        pipe._compute_expert_granular_directions()

        assert len(pipe._expert_directions) == 0

    def test_skips_expert_with_low_routing_weight(self):
        """Experts with insufficient routing weight should not get directions."""
        hidden, n_samples = 16, 3

        pipe = AbliterationPipeline(model_name="test", method="surgical")
        pipe._on_log = lambda m: None
        pipe._strong_layers = [0]

        # Create routing logits where expert 3 is never selected (very low)
        def _logits():
            return torch.tensor([5.0, 5.0, 5.0, -100.0])  # expert 3 never routed

        pipe._routing_harmful = {0: [_logits() for _ in range(n_samples)]}
        pipe._routing_harmless = {0: [_logits() for _ in range(n_samples)]}

        torch.manual_seed(42)
        pipe._harmful_acts = {0: [torch.randn(hidden) for _ in range(n_samples)]}
        pipe._harmless_acts = {0: [torch.randn(hidden) for _ in range(n_samples)]}

        pipe._compute_expert_granular_directions()

        # Expert 3 should NOT have a direction (routing weight too low)
        if 0 in pipe._expert_directions:
            layer_dirs = pipe._expert_directions[0]
            assert 3 not in layer_dirs, (
                "Expert with near-zero routing weight should not get a direction"
            )
|
|
|
|
|
|
class TestProjectMoEExpertsGranular:
    """Test _project_moe_experts_granular (ModuleList path)."""

    def _make_direction(self, hidden_dim=16):
        """Return a random unit-norm column vector of shape (hidden_dim, 1)."""
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    @staticmethod
    def _unit_vector(hidden_dim):
        """Return a random unit-norm 1-D vector of length ``hidden_dim``."""
        d = torch.randn(hidden_dim)
        return d / d.norm()

    @staticmethod
    def _make_moe(hidden, n_experts, with_up_proj=True, with_shared_expert=False):
        """Build a fake MoE block with a ``gate`` router and ``experts`` list.

        Args:
            hidden: Input width of every Linear.
            n_experts: Number of experts (and router outputs).
            with_up_proj: Give each expert an ``up_proj`` next to ``down_proj``.
            with_shared_expert: Add a ``shared_expert`` submodule carrying its
                own ``down_proj`` and ``up_proj``.
        """

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                if with_up_proj:
                    self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                if with_shared_expert:
                    self.shared_expert = torch.nn.Module()
                    self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                    self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        return FakeMoE()

    @staticmethod
    def _randomize(module, seed=42):
        """Seed the RNG, then overwrite every parameter with normal noise."""
        torch.manual_seed(seed)
        for p in module.parameters():
            p.data = torch.randn_like(p.data)

    @staticmethod
    def _make_pipeline(expert_dirs):
        """Return a surgical pipeline with ``expert_dirs`` set for layer 0."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._expert_directions = {0: expert_dirs}
        return pipeline

    def test_per_expert_directions_applied(self):
        """Each expert should use its own direction when available."""
        hidden = 16
        n_experts = 4

        moe = self._make_moe(hidden, n_experts)
        self._randomize(moe)

        shared_dir = self._make_direction(hidden)

        # Create distinct per-expert directions
        expert_dirs = {ei: self._unit_vector(hidden) for ei in range(n_experts)}
        pipeline = self._make_pipeline(expert_dirs)

        # Save originals
        orig_weights = {
            ei: moe.experts[ei].down_proj.weight.data.clone()
            for ei in range(n_experts)
        }

        count = pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )

        assert count > 0, "Should project some weights"

        # All experts should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                moe.experts[ei].down_proj.weight.data, orig_weights[ei]
            ), f"Expert {ei} should be modified"

    def test_falls_back_to_shared_direction(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        n_experts = 4

        moe = self._make_moe(hidden, n_experts)
        self._randomize(moe)

        shared_dir = self._make_direction(hidden)

        # Only expert 0 has a per-expert direction (properly unit-normalized).
        pipeline = self._make_pipeline({0: self._unit_vector(hidden)})

        fallback_experts = range(1, n_experts)
        orig_weights = {
            ei: moe.experts[ei].down_proj.weight.data.clone()
            for ei in fallback_experts
        }

        pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )

        # Experts 1,2,3 should be modified (using shared direction)
        for ei in fallback_experts:
            assert not torch.allclose(
                moe.experts[ei].down_proj.weight.data, orig_weights[ei]
            ), f"Expert {ei} should use shared direction fallback"

    def test_router_uses_shared_direction(self):
        """Router should always use the shared direction, not per-expert."""
        hidden = 16
        n_experts = 4

        # Seed before construction so the default weight init is reproducible.
        torch.manual_seed(42)
        moe = self._make_moe(hidden, n_experts, with_up_proj=False)
        shared_dir = self._make_direction(hidden)

        # Unit-norm, like the directions the other tests in this class use.
        pipeline = self._make_pipeline({0: self._unit_vector(hidden)})

        orig_gate = moe.gate.weight.data.clone()

        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)

        # Gate should be projected
        assert not torch.allclose(moe.gate.weight.data, orig_gate), \
            "Router should be projected with shared direction"

        # Gate's projection onto shared direction should be near zero
        proj = (moe.gate.weight.data @ shared_dir).norm().item()
        assert proj < 1e-4, f"Router should have shared dir removed, proj={proj}"

    def test_shared_expert_uses_shared_direction(self):
        """Shared expert should always use the shared direction."""
        hidden = 16

        # Seed before construction so the default weight init is reproducible.
        torch.manual_seed(42)
        moe = self._make_moe(hidden, 2, with_shared_expert=True)
        shared_dir = self._make_direction(hidden)

        pipeline = self._make_pipeline({0: self._unit_vector(hidden)})

        orig_shared = moe.shared_expert.down_proj.weight.data.clone()

        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)

        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, orig_shared), \
            "Shared expert should be projected"
|
|
|
|
|
|
class TestProjectFused3DGranular:
    """Test _project_fused_3d_granular for fused 3D expert tensors."""

    @staticmethod
    def _make_fused_container(n_experts=4, intermediate=32, hidden=16, seed=42):
        """Return a module whose ``down_proj`` is a seeded fused 3D parameter.

        The RNG is seeded *before* the parameter is drawn so the weights under
        test are reproducible. ``down_proj`` has shape
        ``(n_experts, intermediate, hidden)``.
        """
        torch.manual_seed(seed)

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(
                    torch.randn(n_experts, intermediate, hidden)
                )

        return FusedExperts()

    @staticmethod
    def _unit(*shape):
        """Return a random tensor of ``shape`` scaled to unit norm."""
        d = torch.randn(*shape)
        return d / d.norm()

    def test_per_expert_directions_on_fused(self):
        """Each expert slice should use its own direction."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        container = self._make_fused_container(n_experts, intermediate, hidden)
        shared_dir = self._unit(hidden, 1)

        # Per-expert directions
        expert_dirs = {ei: self._unit(hidden) for ei in range(n_experts)}

        orig_data = container.down_proj.data.clone()

        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )

        assert count == n_experts, f"Should project {n_experts} experts, got {count}"

        # Each expert should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_fallback_to_shared_on_fused(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        container = self._make_fused_container(n_experts, intermediate, hidden)
        shared_dir = self._unit(hidden, 1)

        # Only expert 0 has a direction
        expert_dirs = {0: self._unit(hidden)}

        orig_data = container.down_proj.data.clone()

        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )

        assert count == n_experts
        # All experts should be modified (experts 1-3 use shared dir)
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_norm_preserve_on_fused(self):
        """Fused 3D with norm_preserve should maintain per-expert norms."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        container = self._make_fused_container(n_experts, intermediate, hidden)
        shared_dir = self._unit(hidden, 1)
        expert_dirs = {ei: self._unit(hidden) for ei in range(n_experts)}

        orig_norms = [container.down_proj.data[i].norm().item() for i in range(n_experts)]

        AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=True, scale=1.0,
        )

        for i, orig_norm in enumerate(orig_norms):
            new_norm = container.down_proj.data[i].norm().item()
            assert abs(orig_norm - new_norm) < 1e-3, (
                f"Expert {i} norm not preserved: {orig_norm:.4f} vs {new_norm:.4f}"
            )

    def test_skips_non_3d_params(self):
        """Should skip parameters that are not 3-dimensional."""
        hidden = 16

        class FlatExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # 2-D on purpose: nothing here qualifies as a fused 3D tensor.
                self.down_proj = torch.nn.Parameter(torch.randn(32, hidden))

        container = FlatExperts()
        shared_dir = self._unit(hidden, 1)

        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, {}, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == 0
|
|
|
|
|
|
class TestEGAExciseIntegration:
|
|
"""Test that EGA integrates properly in the excise stage path."""
|
|
|
|
def test_ega_pipeline_flags(self):
    """A pipeline built with method="surgical" has per_expert_directions on."""
    surgical = AbliterationPipeline(model_name="test", method="surgical")
    flag = surgical.per_expert_directions
    assert flag is True
|
|
|
|
def test_ega_only_on_primary_direction(self):
    """EGA should only apply for dir_idx==0, not higher SVD directions.

    This is a structural check: it inspects the source text of
    ``_excise_inner`` for the guard and for the granular-projection call.
    """
    import inspect

    from obliteratus.abliterate import AbliterationPipeline

    excise_src = inspect.getsource(AbliterationPipeline._excise_inner)

    has_primary_guard = "dir_idx == 0" in excise_src
    calls_granular = "_project_moe_experts_granular" in excise_src

    assert has_primary_guard, "EGA should only apply for primary direction"
    assert calls_granular, "EGA method should be called in excise"
|
|
|
|
def test_ega_distill_integration(self):
    """The source of ``_distill`` must reference the EGA entry point and flag."""
    import inspect

    from obliteratus.abliterate import AbliterationPipeline

    distill_src = inspect.getsource(AbliterationPipeline._distill)

    for needle in ("_compute_expert_granular_directions", "per_expert_directions"):
        assert needle in distill_src
|
|
|
|
def test_nuclear_method_enables_ega(self):
    """Nuclear enables per_expert_directions in METHODS and on the pipeline."""
    nuclear_cfg = METHODS["nuclear"]
    assert nuclear_cfg["per_expert_directions"] is True

    nuclear = AbliterationPipeline(model_name="test", method="nuclear")
    assert nuclear.per_expert_directions is True
|
|
|
|
def test_basic_method_disables_ega(self):
    """Basic config has per_expert_directions absent or explicitly False."""
    basic_cfg = METHODS["basic"]
    flag = basic_cfg.get("per_expert_directions", False)
    assert flag is False
|
|
|
|
def test_inverted_method_enables_ega(self):
    """The "inverted" entry in METHODS sets per_expert_directions to True."""
    inverted_cfg = METHODS["inverted"]
    flag = inverted_cfg["per_expert_directions"]
    assert flag is True
|
|
|
|
def test_ega_with_routing_data_end_to_end(self):
    """End-to-end: EGA computes directions, then projection changes weights.

    Step 1 feeds synthetic routing logits and separated activations into
    ``_compute_expert_granular_directions``. Step 2 runs
    ``_project_moe_experts_granular`` on a fake MoE block and checks that
    expert 0's ``down_proj`` was altered.
    """
    hidden, n_experts, n_samples = 16, 4, 5

    class FakeExpert(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
            self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

    class FakeMoE(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
            self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

    block = FakeMoE()
    torch.manual_seed(42)
    for param in block.parameters():
        param.data = torch.randn_like(param.data)

    pipe = AbliterationPipeline(model_name="test", method="surgical")
    pipe._on_log = lambda m: None
    pipe._on_stage = lambda r: None
    pipe._strong_layers = [0]

    # Simulate EGA routing data
    pipe._routing_harmful = {0: [torch.randn(n_experts) for _ in range(n_samples)]}
    pipe._routing_harmless = {0: [torch.randn(n_experts) for _ in range(n_samples)]}

    # Simulate activations with clear separation
    axis = torch.randn(hidden)
    axis = axis / axis.norm()
    shift = 2 * axis
    pipe._harmful_acts = {0: [torch.randn(hidden) + shift for _ in range(n_samples)]}
    pipe._harmless_acts = {0: [torch.randn(hidden) - shift for _ in range(n_samples)]}

    # Step 1: compute EGA directions
    pipe._compute_expert_granular_directions()
    assert 0 in pipe._expert_directions
    assert len(pipe._expert_directions[0]) > 0

    # Step 2: apply granular projection
    shared = torch.randn(hidden, 1)
    shared = shared / shared.norm()

    first_expert = block.experts[0].down_proj
    before = first_expert.weight.data.clone()

    n_projected = pipe._project_moe_experts_granular(
        block, shared, layer_idx=0,
    )

    assert n_projected > 0
    assert not torch.allclose(block.experts[0].down_proj.weight.data, before), \
        "Expert weights should be modified by EGA"
|