# Mirror of https://github.com/elder-plinius/OBLITERATUS.git
# (synced 2026-04-23 19:56:15 +02:00; 60 lines, 1.6 KiB, Python)
"""Tests for configuration loading."""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import yaml
|
|
|
|
from obliteratus.config import StudyConfig
|
|
|
|
|
|
# Plain-dict study configuration used as input to StudyConfig.from_dict and,
# after being dumped to YAML, to StudyConfig.from_yaml.
SAMPLE_CONFIG = {
    "model": {
        "name": "gpt2",
        "task": "causal_lm",
        "dtype": "float32",
        "device": "cpu",
    },
    "dataset": {
        "name": "wikitext",
        "subset": "wikitext-2-raw-v1",
        "split": "test",
        "text_column": "text",
        "max_samples": 50,
    },
    "strategies": [
        {"name": "layer_removal", "params": {}},
        {"name": "ffn_ablation", "params": {}},
    ],
    "metrics": ["perplexity"],
    "batch_size": 4,
    "max_length": 256,
    "output_dir": "results/test",
}
|
|
|
|
|
|
class TestStudyConfig:
    """Tests for building StudyConfig objects from dicts and YAML files."""

    def test_from_dict(self) -> None:
        """from_dict exposes the model, dataset and strategy entries as attributes."""
        config = StudyConfig.from_dict(SAMPLE_CONFIG)

        assert config.model.name == "gpt2"
        assert config.model.task == "causal_lm"
        assert config.dataset.name == "wikitext"
        assert len(config.strategies) == 2
        assert config.strategies[0].name == "layer_removal"

    def test_from_yaml(self, tmp_path) -> None:
        """from_yaml loads a config that was dumped to a YAML file on disk.

        ``tmp_path`` is used as a directory path supporting ``/`` and
        ``write_text`` (by name, pytest's built-in ``tmp_path`` fixture).
        """
        yaml_path = tmp_path / "test_config.yaml"
        yaml_path.write_text(yaml.dump(SAMPLE_CONFIG))

        config = StudyConfig.from_yaml(yaml_path)

        assert config.model.name == "gpt2"
        assert config.batch_size == 4

    def test_roundtrip(self) -> None:
        """from_dict -> to_dict -> from_dict keeps the checked fields unchanged."""
        config = StudyConfig.from_dict(SAMPLE_CONFIG)
        d = config.to_dict()
        config2 = StudyConfig.from_dict(d)

        assert config2.model.name == config.model.name
        assert config2.dataset.name == config.dataset.name
        assert len(config2.strategies) == len(config.strategies)
|