mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-06-09 07:43:57 +02:00
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
"""YAML-based configuration for ablation runs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Settings for the ``model`` section of a study configuration.

    Only ``name`` is required; every other field has a default. Instances
    are built from a plain mapping via ``ModelConfig(**mapping)``, so the
    field names double as the accepted configuration keys.
    """

    # Required identifier of the model.
    # NOTE(review): presumably a hub id or local path -- confirm against the loader.
    name: str

    # Kind of task/head the model is loaded for.
    task: str = "causal_lm"

    # Numeric precision, given as a string name.
    dtype: str = "float32"

    # Device placement; "auto" is the default value.
    # NOTE(review): how "auto" is resolved is decided outside this module.
    device: str = "auto"

    # Opt-in flag, disabled by default.
    trust_remote_code: bool = False

    # NOTE(review): presumably only meaningful for classification-style tasks -- confirm.
    num_labels: int = 2
|
|
|
|
|
|
@dataclass
class DatasetConfig:
    """Settings for the ``dataset`` section of a study configuration.

    Only ``name`` is required. Instances are built from a plain mapping via
    ``DatasetConfig(**mapping)``, so the field names double as the accepted
    configuration keys.
    """

    # Required identifier of the dataset.
    name: str

    # Which split to use; defaults to the test split.
    split: str = "test"

    # Optional subset name; ``None`` means no subset was requested.
    subset: str | None = None

    # Column names holding the input text and the label.
    text_column: str = "text"
    label_column: str = "label"

    # Optional sample cap.
    # NOTE(review): ``None`` presumably means "no cap" -- confirm with the consumer.
    max_samples: int | None = None
|
|
|
|
|
|
@dataclass
class StrategyConfig:
    """One entry of the ``strategies`` list in a study configuration.

    Pairs a strategy ``name`` with a free-form ``params`` mapping. Instances
    are built from a plain mapping via ``StrategyConfig(**mapping)``.
    """

    # Required strategy identifier.
    name: str

    # Strategy-specific keyword settings. A factory is used so that each
    # instance gets its own empty dict rather than sharing one default.
    params: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class StudyConfig:
    """Top-level configuration for an ablation run.

    Bundles the model, dataset and strategy sections with the run-wide
    settings (metrics, batch size, maximum length, output directory).
    Build one from a YAML file with :meth:`from_yaml`, from an already
    parsed mapping with :meth:`from_dict`, and turn it back into plain
    built-in types with :meth:`to_dict`.
    """

    model: ModelConfig
    dataset: DatasetConfig
    strategies: list[StrategyConfig]
    metrics: list[str] = field(default_factory=lambda: ["perplexity"])
    batch_size: int = 8
    max_length: int = 512
    output_dir: str = "results"

    @classmethod
    def from_yaml(cls, path: str | Path) -> StudyConfig:
        """Load a configuration from the YAML file at *path*.

        The file is decoded as UTF-8 explicitly so the result does not
        depend on the platform's default encoding.

        Raises:
            TypeError: if the document's top level is not a mapping
                (for example, the file is empty).
            OSError: if the file cannot be read.
        """
        raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
        if not isinstance(raw, dict):
            # An empty file parses to None and a top-level list/scalar is not
            # a config either; fail with a message that names the file.
            raise TypeError(
                f"{path}: expected a YAML mapping at the top level, "
                f"got {type(raw).__name__}"
            )
        return cls.from_dict(raw)

    @classmethod
    def from_dict(cls, d: dict) -> StudyConfig:
        """Build a configuration from an already parsed mapping.

        Required keys are ``model``, ``dataset`` and ``strategies``; when a
        preset is named (``study_preset``, or its alias ``preset``) the
        preset supplies ``strategies`` if the mapping does not. ``metrics``,
        ``batch_size``, ``max_length`` and ``output_dir`` are optional.

        With a preset, the preset's ``strategies``, ``metrics``,
        ``batch_size`` and ``max_length`` are used only for keys absent from
        *d*, and its ``max_samples`` fills a non-empty ``dataset`` section
        that does not set ``max_samples`` itself.

        The input mapping is not modified.

        Raises:
            KeyError: if a required key is missing.
            TypeError: if a section contains unknown or missing fields.
        """
        # Shallow copy: preset defaults are merged into this working copy
        # instead of being written back into the caller's dict.
        cfg = dict(d)

        # Accept both "preset" and legacy "study_preset" keys.
        if "preset" in cfg and "study_preset" not in cfg:
            cfg["study_preset"] = cfg["preset"]

        # If a study_preset key is provided, use it as the base and allow
        # the rest of the config to override individual fields.
        if "study_preset" in cfg:
            # NOTE(review): imported here rather than at module top --
            # presumably to avoid a circular import; confirm before moving.
            from obliteratus.study_presets import get_study_preset

            preset = get_study_preset(cfg["study_preset"])

            # Preset provides defaults; explicit keys in the dict override.
            if "strategies" not in cfg:
                cfg["strategies"] = preset.strategies
            if "metrics" not in cfg:
                # Copied so later edits to this config's metrics list cannot
                # alter the preset object it came from.
                cfg["metrics"] = list(preset.metrics)
            if "batch_size" not in cfg:
                cfg["batch_size"] = preset.batch_size
            if "max_length" not in cfg:
                cfg["max_length"] = preset.max_length

            # Preset max_samples flows into dataset if not set. The section
            # is copied first so the caller's nested dict stays untouched.
            dataset_section = dict(cfg.get("dataset", {}))
            if dataset_section and "max_samples" not in dataset_section:
                dataset_section["max_samples"] = preset.max_samples
            cfg["dataset"] = dataset_section

        return cls(
            model=ModelConfig(**cfg["model"]),
            dataset=DatasetConfig(**cfg["dataset"]),
            strategies=[StrategyConfig(**entry) for entry in cfg["strategies"]],
            metrics=cfg.get("metrics", ["perplexity"]),
            batch_size=cfg.get("batch_size", 8),
            max_length=cfg.get("max_length", 512),
            output_dir=cfg.get("output_dir", "results"),
        )

    def to_dict(self) -> dict:
        """Return the configuration as nested built-in containers.

        The result uses the same keys :meth:`from_dict` accepts. Every
        ``ModelConfig`` field, including ``num_labels``, is emitted so that
        ``from_dict(cfg.to_dict())`` reproduces the model section rather
        than falling back to the default label count.
        """
        model = self.model
        dataset = self.dataset
        return {
            "model": {
                "name": model.name,
                "task": model.task,
                "dtype": model.dtype,
                "device": model.device,
                "trust_remote_code": model.trust_remote_code,
                "num_labels": model.num_labels,
            },
            "dataset": {
                "name": dataset.name,
                "split": dataset.split,
                "subset": dataset.subset,
                "text_column": dataset.text_column,
                "label_column": dataset.label_column,
                "max_samples": dataset.max_samples,
            },
            "strategies": [
                {"name": strategy.name, "params": strategy.params}
                for strategy in self.strategies
            ],
            "metrics": self.metrics,
            "batch_size": self.batch_size,
            "max_length": self.max_length,
            "output_dir": self.output_dir,
        }
|