# Mirror of https://github.com/elder-plinius/OBLITERATUS.git
# Synced 2026-06-09 15:54:03 +02:00
# 180 lines, 6.0 KiB, Python
"""Tests for ablation strategies using a small GPT-2 model."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from obliteratus.strategies.base import AblationSpec
|
|
from obliteratus.strategies.registry import STRATEGY_REGISTRY, get_strategy
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_dummy_handle():
    """Build a ModelHandle around a tiny GPT-2 for the strategy tests.

    The model is instantiated directly from an in-memory ``GPT2Config`` and the
    tokenizer is a ``MagicMock``, so no network access is needed.

    Returns:
        A ``ModelHandle`` on which ``snapshot()`` has already been called.
    """
    from unittest.mock import MagicMock

    from transformers import GPT2Config, GPT2LMHeadModel

    from obliteratus.models.loader import ModelHandle

    # Dimensions of the throwaway test model.
    tiny_dims = {
        "vocab_size": 1000,
        "n_positions": 128,
        "n_embd": 64,
        "n_layer": 2,
        "n_head": 2,
        "n_inner": 256,
    }
    gpt2_config = GPT2Config(**tiny_dims)

    gpt2_model = GPT2LMHeadModel(gpt2_config)
    gpt2_model.eval()

    # The strategy tests never tokenize, so a mock with the two special-token
    # attributes set stands in for a real tokenizer.
    fake_tokenizer = MagicMock()
    fake_tokenizer.pad_token = "<pad>"
    fake_tokenizer.eos_token = "<eos>"

    dummy = ModelHandle(
        model=gpt2_model,
        tokenizer=fake_tokenizer,
        config=gpt2_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    dummy.snapshot()
    return dummy
|
|
|
|
|
|
@pytest.fixture
def handle():
    """Pytest fixture that returns the handle built by ``_make_dummy_handle()``."""
    dummy_handle = _make_dummy_handle()
    return dummy_handle
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Registry tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRegistry:
    """Checks on ``STRATEGY_REGISTRY`` and the ``get_strategy`` lookup."""

    def test_all_strategies_registered(self):
        """The four named strategies must all be keys of the registry."""
        required = {"layer_removal", "head_pruning", "ffn_ablation", "embedding_ablation"}
        registered = set(STRATEGY_REGISTRY.keys())
        assert required <= registered

    def test_get_strategy_returns_instance(self):
        """The object returned for a known name reports that same name."""
        wanted = "layer_removal"
        assert get_strategy(wanted).name == wanted

    def test_get_unknown_strategy_raises(self):
        """An unregistered name raises ``KeyError`` mentioning 'Unknown strategy'."""
        with pytest.raises(KeyError, match="Unknown strategy"):
            get_strategy("nonexistent_strategy")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Layer removal
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLayerRemoval:
    """Tests for the ``layer_removal`` strategy."""

    def test_enumerate(self, handle):
        """One spec per layer, each tagged with the strategy's name."""
        layer_specs = get_strategy("layer_removal").enumerate(handle)
        assert len(layer_specs) == handle.num_layers
        assert not any(spec.strategy_name != "layer_removal" for spec in layer_specs)

    def test_apply_zeros_layer(self, handle):
        """Applying the first spec leaves every parameter of layer 0 at zero."""
        from obliteratus.strategies.utils import get_layer_modules

        remover = get_strategy("layer_removal")
        first_spec = remover.enumerate(handle)[0]
        remover.apply(handle, first_spec)

        first_layer = get_layer_modules(handle)[0]
        for weights in first_layer.parameters():
            assert torch.all(weights == 0), "Layer params should be zeroed after ablation"

    def test_restore_after_ablation(self, handle):
        """``handle.restore()`` brings layer 0's ``c_attn`` weight back after ablation."""
        from obliteratus.strategies.utils import get_layer_modules

        remover = get_strategy("layer_removal")
        layer_specs = remover.enumerate(handle)

        # Clone before ablating so the comparison uses an independent copy.
        weight_before = get_layer_modules(handle)[0].attn.c_attn.weight.clone()

        remover.apply(handle, layer_specs[0])
        handle.restore()

        # Look the layer up again after restore rather than reusing a reference.
        weight_after = get_layer_modules(handle)[0].attn.c_attn.weight
        assert torch.allclose(weight_before, weight_after)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Head pruning
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestHeadPruning:
    """Tests for the ``head_pruning`` strategy."""

    def test_enumerate(self, handle):
        """One spec is enumerated per (layer, head) pair."""
        strat = get_strategy("head_pruning")
        specs = strat.enumerate(handle)
        assert len(specs) == handle.num_layers * handle.num_heads

    def test_apply_zeros_head(self, handle):
        """Pruning layer 0 / head 0 zeroes that head's rows of ``c_proj``."""
        strat = get_strategy("head_pruning")
        spec = AblationSpec(
            strategy_name="head_pruning",
            component="layer_0_head_0",
            description="test",
            metadata={"layer_idx": 0, "head_idx": 0},
        )
        strat.apply(handle, spec)

        from obliteratus.strategies.utils import get_layer_modules, get_attention_module

        attn = get_attention_module(get_layer_modules(handle)[0], handle.architecture)
        head_dim = handle.hidden_size // handle.num_heads

        # GPT-2 uses c_attn (Conv1D); the check targets the output projection
        # c_proj. The fixture always builds GPT-2, so a missing c_proj means the
        # wrong module was resolved. Fail loudly here instead of letting the
        # test pass without asserting anything.
        assert hasattr(attn, "c_proj"), (
            "Expected the GPT-2 attention module to expose c_proj"
        )
        # Conv1D stores weight transposed, so head 0 is checked via the first
        # head_dim rows.
        assert torch.all(attn.c_proj.weight[0:head_dim, :] == 0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FFN ablation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestFFNAblation:
    """Tests for the ``ffn_ablation`` strategy."""

    def test_enumerate(self, handle):
        """One spec is enumerated per layer."""
        ffn_specs = get_strategy("ffn_ablation").enumerate(handle)
        assert len(ffn_specs) == handle.num_layers

    def test_apply_zeros_ffn(self, handle):
        """Applying the first spec leaves every FFN parameter of layer 0 at zero."""
        from obliteratus.strategies.utils import get_ffn_module, get_layer_modules

        ablator = get_strategy("ffn_ablation")
        first_spec = ablator.enumerate(handle)[0]
        ablator.apply(handle, first_spec)

        first_layer = get_layer_modules(handle)[0]
        feed_forward = get_ffn_module(first_layer, handle.architecture)
        for weights in feed_forward.parameters():
            assert torch.all(weights == 0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Embedding ablation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEmbeddingAblation:
    """Tests for the ``embedding_ablation`` strategy."""

    def test_enumerate(self, handle):
        """Enumeration yields at least one spec."""
        embedding_specs = get_strategy("embedding_ablation").enumerate(handle)
        assert len(embedding_specs) > 0

    def test_apply_zeros_dims(self, handle):
        """Ablating dims 0..4 zeroes those columns of the embedding weight."""
        from obliteratus.strategies.utils import get_embedding_module

        first_dim, stop_dim = 0, 4
        dims_spec = AblationSpec(
            strategy_name="embedding_ablation",
            component="embed_dims_0_4",
            description="test",
            metadata={"dim_start": first_dim, "dim_end": stop_dim},
        )
        get_strategy("embedding_ablation").apply(handle, dims_spec)

        embedding = get_embedding_module(handle)
        ablated_columns = embedding.weight[:, first_dim:stop_dim]
        assert torch.all(ablated_columns == 0)
|