Files
OBLITERATUS/tests/test_strategies.py
2026-03-05 00:50:44 -08:00

180 lines
6.0 KiB
Python

"""Tests for ablation strategies using a small GPT-2 model."""
from __future__ import annotations
import pytest
import torch
from obliteratus.strategies.base import AblationSpec
from obliteratus.strategies.registry import STRATEGY_REGISTRY, get_strategy
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_dummy_handle():
    """Build a ModelHandle wrapping a tiny GPT-2 for testing (no network).

    The model is instantiated directly from an in-memory ``GPT2Config``
    (2 layers, 2 heads, 64-dim embeddings) and put in eval mode. The
    tokenizer is a ``MagicMock`` exposing only ``pad_token`` and
    ``eos_token``. ``snapshot()`` is called on the handle before it is
    returned.
    """
    from unittest.mock import MagicMock

    from transformers import GPT2Config, GPT2LMHeadModel

    from obliteratus.models.loader import ModelHandle

    tiny_config = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=2,
        n_head=2,
        n_inner=256,
    )
    # nn.Module.eval() returns the module itself, so the call can be chained.
    tiny_model = GPT2LMHeadModel(tiny_config).eval()

    # Strategy tests don't tokenize — a mock with the two token attributes
    # is enough.
    fake_tokenizer = MagicMock(pad_token="<pad>", eos_token="<eos>")

    dummy = ModelHandle(
        model=tiny_model,
        tokenizer=fake_tokenizer,
        config=tiny_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    # NOTE(review): presumably records the baseline that handle.restore()
    # rolls back to — confirm against ModelHandle.
    dummy.snapshot()
    return dummy
@pytest.fixture
def handle():
    """Return a handle built by ``_make_dummy_handle`` for the requesting test."""
    dummy_handle = _make_dummy_handle()
    return dummy_handle
# ---------------------------------------------------------------------------
# Registry tests
# ---------------------------------------------------------------------------
class TestRegistry:
    """Checks on ``STRATEGY_REGISTRY`` and the ``get_strategy`` lookup."""

    def test_all_strategies_registered(self):
        """The four expected strategy names are all keys of the registry."""
        required_names = {
            "layer_removal",
            "head_pruning",
            "ffn_ablation",
            "embedding_ablation",
        }
        registered_names = set(STRATEGY_REGISTRY.keys())
        # Subset, not equality: extra registered strategies are allowed.
        assert required_names <= registered_names

    def test_get_strategy_returns_instance(self):
        """Looking up a known name yields an object carrying that name."""
        layer_removal = get_strategy("layer_removal")
        assert layer_removal.name == "layer_removal"

    def test_get_unknown_strategy_raises(self):
        """An unregistered name raises KeyError mentioning 'Unknown strategy'."""
        with pytest.raises(KeyError, match="Unknown strategy"):
            get_strategy("nonexistent_strategy")
# ---------------------------------------------------------------------------
# Layer removal
# ---------------------------------------------------------------------------
class TestLayerRemoval:
    """Tests for the ``layer_removal`` strategy."""

    def test_enumerate(self, handle):
        """One spec is produced per layer, each tagged with the strategy name."""
        layer_specs = get_strategy("layer_removal").enumerate(handle)
        assert len(layer_specs) == handle.num_layers
        for layer_spec in layer_specs:
            assert layer_spec.strategy_name == "layer_removal"

    def test_apply_zeros_layer(self, handle):
        """Applying the first spec leaves every parameter of layer 0 at zero."""
        from obliteratus.strategies.utils import get_layer_modules

        strategy = get_strategy("layer_removal")
        first_spec = strategy.enumerate(handle)[0]
        strategy.apply(handle, first_spec)

        first_layer = get_layer_modules(handle)[0]
        for tensor in first_layer.parameters():
            assert (tensor == 0).all(), "Layer params should be zeroed after ablation"

    def test_restore_after_ablation(self, handle):
        """``handle.restore()`` brings layer 0's c_attn weight back after ablation."""
        from obliteratus.strategies.utils import get_layer_modules

        strategy = get_strategy("layer_removal")
        layer_specs = strategy.enumerate(handle)

        # Clone before ablating so the reference copy is unaffected by any
        # in-place change to the live weight.
        first_layer = get_layer_modules(handle)[0]
        weight_before = first_layer.attn.c_attn.weight.clone()

        strategy.apply(handle, layer_specs[0])
        handle.restore()

        # Re-fetch the layer modules after restore rather than reusing the
        # earlier reference.
        weight_after = get_layer_modules(handle)[0].attn.c_attn.weight
        assert torch.allclose(weight_before, weight_after)
# ---------------------------------------------------------------------------
# Head pruning
# ---------------------------------------------------------------------------
class TestHeadPruning:
    """Tests for the ``head_pruning`` strategy."""

    def test_enumerate(self, handle):
        """The number of specs equals ``num_layers * num_heads``."""
        strat = get_strategy("head_pruning")
        specs = strat.enumerate(handle)
        assert len(specs) == handle.num_layers * handle.num_heads

    def test_apply_zeros_head(self, handle):
        """Pruning layer 0 / head 0 zeroes that head's rows in ``c_proj``.

        The check only applies to attention modules exposing ``c_proj``
        (as GPT-2 does). For any other module the test is skipped
        explicitly instead of passing without having asserted anything.
        """
        from obliteratus.strategies.utils import get_attention_module, get_layer_modules

        strat = get_strategy("head_pruning")
        spec = AblationSpec(
            strategy_name="head_pruning",
            component="layer_0_head_0",
            description="test",
            metadata={"layer_idx": 0, "head_idx": 0},
        )
        strat.apply(handle, spec)

        attn = get_attention_module(get_layer_modules(handle)[0], handle.architecture)
        # GPT-2 uses c_attn (Conv1D); the check is made on the output
        # projection c_proj. Without it there is nothing to verify, so make
        # that visible in the test report rather than silently passing.
        if not hasattr(attn, "c_proj"):
            pytest.skip("attention module has no c_proj; head-zeroing check not applicable")

        head_dim = handle.hidden_size // handle.num_heads
        # Conv1D stores weight transposed, so head 0 occupies the first
        # head_dim rows of c_proj.weight.
        assert torch.all(attn.c_proj.weight[0:head_dim, :] == 0), (
            "c_proj rows for head 0 should be zeroed after pruning"
        )
# ---------------------------------------------------------------------------
# FFN ablation
# ---------------------------------------------------------------------------
class TestFFNAblation:
    """Tests for the ``ffn_ablation`` strategy."""

    def test_enumerate(self, handle):
        """The number of specs equals the handle's layer count."""
        ffn_specs = get_strategy("ffn_ablation").enumerate(handle)
        assert len(ffn_specs) == handle.num_layers

    def test_apply_zeros_ffn(self, handle):
        """Applying the first spec leaves every FFN parameter of layer 0 at zero."""
        from obliteratus.strategies.utils import get_ffn_module, get_layer_modules

        strategy = get_strategy("ffn_ablation")
        first_spec = strategy.enumerate(handle)[0]
        strategy.apply(handle, first_spec)

        first_layer = get_layer_modules(handle)[0]
        ffn_block = get_ffn_module(first_layer, handle.architecture)
        for tensor in ffn_block.parameters():
            assert (tensor == 0).all()
# ---------------------------------------------------------------------------
# Embedding ablation
# ---------------------------------------------------------------------------
class TestEmbeddingAblation:
    """Tests for the ``embedding_ablation`` strategy."""

    def test_enumerate(self, handle):
        """Enumeration yields at least one spec."""
        embedding_specs = get_strategy("embedding_ablation").enumerate(handle)
        spec_count = len(embedding_specs)
        assert spec_count > 0

    def test_apply_zeros_dims(self, handle):
        """Ablating dims 0..4 zeroes the first four embedding-weight columns."""
        from obliteratus.strategies.utils import get_embedding_module

        strategy = get_strategy("embedding_ablation")
        dims_spec = AblationSpec(
            strategy_name="embedding_ablation",
            component="embed_dims_0_4",
            description="test",
            metadata={"dim_start": 0, "dim_end": 4},
        )
        strategy.apply(handle, dims_spec)

        embedding = get_embedding_module(handle)
        ablated_columns = embedding.weight[:, :4]
        assert torch.all(ablated_columns == 0)