mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-23 19:56:15 +02:00
109 lines
4.3 KiB
Python
109 lines
4.3 KiB
Python
"""Tests for ablation presets."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from obliteratus.study_presets import (
|
|
STUDY_PRESETS,
|
|
get_study_preset,
|
|
get_preset,
|
|
list_study_presets,
|
|
list_presets,
|
|
)
|
|
from obliteratus.config import StudyConfig
|
|
|
|
|
|
class TestPresets:
    """Checks on the preset registry and its lookup / listing helpers."""

    def test_all_presets_registered(self):
        """Each expected preset key must appear in ``STUDY_PRESETS``."""
        required = {
            "quick",
            "full",
            "attention",
            "layers",
            "knowledge",
            "pruning",
            "embeddings",
            "jailbreak",
            "guardrail",
            "robustness",
        }
        registered = set(STUDY_PRESETS.keys())
        # Subset check: extra registered presets are allowed.
        assert required.issubset(registered)

    def test_get_preset(self):
        """``get_study_preset("quick")`` returns the Quick Scan preset."""
        quick = get_study_preset("quick")
        assert quick.name == "Quick Scan"
        assert quick.key == "quick"
        assert len(quick.strategies) == 2

    def test_get_preset_alias(self):
        """``get_preset`` resolves ``"quick"`` to the Quick Scan preset too."""
        quick = get_preset("quick")
        assert quick.name == "Quick Scan"

    def test_get_unknown_preset_raises(self):
        """An unregistered key raises ``KeyError`` mentioning "Unknown preset"."""
        import pytest

        with pytest.raises(KeyError, match="Unknown preset"):
            get_study_preset("nonexistent")

    def test_list_presets(self):
        """The listing has at least seven entries, including quick and full."""
        listed = list_study_presets()
        assert len(listed) >= 7
        listed_keys = [entry.key for entry in listed]
        assert "quick" in listed_keys
        assert "full" in listed_keys

    def test_list_presets_alias(self):
        """``list_presets`` and ``list_study_presets`` return equal results."""
        via_alias = list_presets()
        via_primary = list_study_presets()
        assert via_alias == via_primary

    def test_preset_strategies_are_valid(self):
        """Every strategy named by any preset exists in ``STRATEGY_REGISTRY``."""
        from obliteratus.strategies import STRATEGY_REGISTRY

        for preset in list_study_presets():
            for s in preset.strategies:
                # Each strategy entry is indexed by "name" for the lookup.
                assert s["name"] in STRATEGY_REGISTRY, (
                    f"Preset {preset.key!r} references unknown strategy {s['name']!r}"
                )
|
|
|
|
|
|
class TestConfigWithPreset:
    """Checks how ``StudyConfig.from_dict`` combines a preset with explicit keys."""

    @staticmethod
    def _gpt2_model():
        """Build the model section used by every test (fresh dict per call)."""
        return {
            "name": "gpt2",
            "task": "causal_lm",
            "dtype": "float32",
            "device": "cpu",
        }

    @staticmethod
    def _wikitext_dataset(**extra):
        """Build the dataset section; ``extra`` keys are appended after the base keys."""
        section = {
            "name": "wikitext",
            "subset": "wikitext-2-raw-v1",
            "split": "test",
            "text_column": "text",
        }
        section.update(extra)
        return section

    def test_preset_key_in_config(self):
        """A config naming ``preset: quick`` picks up the preset's values."""
        raw = {
            "preset": "quick",
            "model": self._gpt2_model(),
            "dataset": self._wikitext_dataset(),
        }
        cfg = StudyConfig.from_dict(raw)

        # Strategies should come from the quick preset.
        assert len(cfg.strategies) == 2
        names = [strategy.name for strategy in cfg.strategies]
        assert "layer_removal" in names
        assert "ffn_ablation" in names

        # max_samples should come from the preset as well.
        assert cfg.dataset.max_samples == 25

        # So should batch_size and max_length.
        assert cfg.batch_size == 4
        assert cfg.max_length == 128

    def test_legacy_study_preset_key_still_works(self):
        """The older ``study_preset`` key is still accepted."""
        raw = {
            "study_preset": "quick",
            "model": self._gpt2_model(),
            "dataset": self._wikitext_dataset(),
        }
        cfg = StudyConfig.from_dict(raw)
        assert len(cfg.strategies) == 2

    def test_preset_can_be_overridden(self):
        """Explicitly supplied values take precedence over the preset's."""
        raw = {
            "preset": "quick",
            "model": self._gpt2_model(),
            "dataset": self._wikitext_dataset(max_samples=999),
            "batch_size": 16,
            "strategies": [{"name": "head_pruning", "params": {}}],
        }
        cfg = StudyConfig.from_dict(raw)

        # Explicit strategies should replace the preset's list.
        assert len(cfg.strategies) == 1
        assert cfg.strategies[0].name == "head_pruning"

        # Explicit batch_size should win.
        assert cfg.batch_size == 16

        # Explicit max_samples in the dataset section should be kept.
        assert cfg.dataset.max_samples == 999

    def test_full_preset(self):
        """The ``full`` preset yields exactly the four expected strategies."""
        raw = {
            "preset": "full",
            "model": self._gpt2_model(),
            "dataset": self._wikitext_dataset(),
        }
        cfg = StudyConfig.from_dict(raw)

        assert len(cfg.strategies) == 4
        names = {strategy.name for strategy in cfg.strategies}
        assert names == {
            "layer_removal",
            "head_pruning",
            "ffn_ablation",
            "embedding_ablation",
        }
|