Files
OBLITERATUS/obliteratus/config.py
T
2026-03-04 12:38:18 -08:00

118 lines
3.7 KiB
Python

"""YAML-based configuration for ablation runs."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
@dataclass
class ModelConfig:
    """Model section of a study configuration.

    Only ``name`` is required; every other field falls back to the default
    declared below.
    """

    # Identifier of the model to use (required, no default).
    name: str
    # Task identifier; defaults to "causal_lm".
    task: str = "causal_lm"
    # Dtype given by name as a string; defaults to "float32".
    dtype: str = "float32"
    # Device selector string; defaults to "auto".
    device: str = "auto"
    # Opt-in flag, off by default.
    # NOTE(review): presumably forwarded to the model loader -- confirm.
    trust_remote_code: bool = False
    # Number of labels; defaults to 2.
    # NOTE(review): presumably only relevant for classification tasks -- confirm.
    num_labels: int = 2
@dataclass
class DatasetConfig:
    """Dataset section of a study configuration.

    Only ``name`` is required; every other field falls back to the default
    declared below.
    """

    # Identifier of the dataset to use (required, no default).
    name: str
    # Dataset split; defaults to "test".
    split: str = "test"
    # Optional subset name; None when no subset is requested.
    subset: str | None = None
    # Name of the column holding the text; defaults to "text".
    text_column: str = "text"
    # Name of the column holding the label; defaults to "label".
    label_column: str = "label"
    # Optional sample cap; None by default.
    # NOTE(review): None presumably means "use every sample" -- confirm.
    max_samples: int | None = None
@dataclass
class StrategyConfig:
    """One strategy entry of a study: a name plus free-form parameters."""

    # Identifier of the strategy (required, no default).
    name: str
    # Strategy-specific keyword parameters. A default_factory is used so each
    # instance gets its own empty dict instead of sharing one mutable default.
    params: dict[str, Any] = field(default_factory=dict)
@dataclass
class StudyConfig:
    """Top-level configuration for an ablation run.

    Groups the model, dataset and strategy sections together with run-wide
    settings (metrics, batch size, maximum length, output directory).
    Instances can be built from a YAML file (``from_yaml``) or from a plain
    mapping (``from_dict``) and converted back with ``to_dict``.
    """

    model: ModelConfig
    dataset: DatasetConfig
    strategies: list[StrategyConfig]
    metrics: list[str] = field(default_factory=lambda: ["perplexity"])
    batch_size: int = 8
    max_length: int = 512
    output_dir: str = "results"

    @classmethod
    def from_yaml(cls, path: str | Path) -> StudyConfig:
        """Load a configuration from a YAML file.

        Args:
            path: Location of the YAML document.

        Returns:
            The ``StudyConfig`` built by ``from_dict`` from the parsed file.

        Raises:
            TypeError: If the document does not parse to a mapping (for
                example an empty file, which parses to ``None``).
        """
        # Read as UTF-8 explicitly so parsing does not depend on the
        # platform's locale encoding.
        text = Path(path).read_text(encoding="utf-8")
        return cls.from_dict(yaml.safe_load(text))

    @classmethod
    def from_dict(cls, d: dict) -> StudyConfig:
        """Build a configuration from a mapping.

        The mapping must provide ``model`` and ``dataset`` sections and,
        unless a preset supplies them, ``strategies``.  A ``preset`` key (or
        the legacy ``study_preset`` key) names a study preset that provides
        defaults for ``strategies``, ``metrics``, ``batch_size`` and
        ``max_length``; keys present in ``d`` always win over the preset.
        The preset's ``max_samples`` is copied into a non-empty ``dataset``
        section that does not set ``max_samples`` itself.

        The input mapping and its ``dataset`` section are never modified.

        Args:
            d: Raw configuration, e.g. the result of parsing a YAML file.

        Returns:
            A new ``StudyConfig``.

        Raises:
            TypeError: If ``d`` is not a mapping, or a section cannot be
                unpacked into its config class.
            KeyError: If a required section is missing.
        """
        from collections.abc import Mapping

        if not isinstance(d, Mapping):
            raise TypeError(
                f"study configuration must be a mapping, got {type(d).__name__}"
            )

        # Work on a shallow copy so the caller's mapping is left untouched.
        cfg = dict(d)

        # Accept both "preset" and legacy "study_preset" keys.
        if "preset" in cfg and "study_preset" not in cfg:
            cfg["study_preset"] = cfg["preset"]

        # If a study_preset key is provided, use it as the base and allow
        # the rest of the config to override individual fields.
        if "study_preset" in cfg:
            from obliteratus.study_presets import get_study_preset

            preset = get_study_preset(cfg["study_preset"])

            # Preset provides defaults; explicit keys in the dict override.
            for key in ("strategies", "metrics", "batch_size", "max_length"):
                if key not in cfg:
                    cfg[key] = getattr(preset, key)

            # Preset max_samples flows into the dataset section if not set.
            # The section is copied first so the caller's nested dict is not
            # mutated; an empty/missing section is left empty on purpose.
            dataset_section = dict(cfg.get("dataset", {}))
            if dataset_section and "max_samples" not in dataset_section:
                dataset_section["max_samples"] = preset.max_samples
            cfg["dataset"] = dataset_section

        model = ModelConfig(**cfg["model"])
        dataset = DatasetConfig(**cfg["dataset"])
        strategies = [StrategyConfig(**s) for s in cfg["strategies"]]

        return cls(
            model=model,
            dataset=dataset,
            strategies=strategies,
            metrics=cfg.get("metrics", ["perplexity"]),
            batch_size=cfg.get("batch_size", 8),
            max_length=cfg.get("max_length", 512),
            output_dir=cfg.get("output_dir", "results"),
        )

    def to_dict(self) -> dict:
        """Return the configuration as a plain nested dict.

        The layout mirrors what ``from_dict`` accepts, so the result can be
        fed back into ``from_dict`` to rebuild the configuration.
        """
        return {
            "model": {
                "name": self.model.name,
                "task": self.model.task,
                "dtype": self.model.dtype,
                "device": self.model.device,
                "trust_remote_code": self.model.trust_remote_code,
                # Included so a to_dict()/from_dict() round trip does not
                # silently reset num_labels to its default.
                "num_labels": self.model.num_labels,
            },
            "dataset": {
                "name": self.dataset.name,
                "split": self.dataset.split,
                "subset": self.dataset.subset,
                "text_column": self.dataset.text_column,
                "label_column": self.dataset.label_column,
                "max_samples": self.dataset.max_samples,
            },
            "strategies": [
                {"name": s.name, "params": s.params} for s in self.strategies
            ],
            "metrics": self.metrics,
            "batch_size": self.batch_size,
            "max_length": self.max_length,
            "output_dir": self.output_dir,
        }