Files
2026-03-04 12:38:18 -08:00

129 lines
4.6 KiB
Python

"""Main ablation runner — orchestrates the full pipeline."""
from __future__ import annotations
from pathlib import Path
from datasets import load_dataset
from rich.console import Console
from obliteratus.config import StudyConfig
from obliteratus.evaluation.evaluator import Evaluator
from obliteratus.models.loader import load_model
from obliteratus.reporting.report import AblationReport, AblationResult
from obliteratus.strategies import get_strategy
console = Console()
def run_study(config: StudyConfig) -> AblationReport:
"""Execute a full ablation study from a StudyConfig.
Steps:
1. Load model from HuggingFace.
2. Load evaluation dataset.
3. Compute baseline metrics.
4. For each strategy, enumerate ablation specs and evaluate each.
5. Collect everything into an AblationReport.
"""
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# --- 1. Load model ---
console.print(f"\n[bold cyan]Loading model:[/bold cyan] {config.model.name}")
handle = load_model(
model_name=config.model.name,
task=config.model.task,
device=config.model.device,
dtype=config.model.dtype,
trust_remote_code=config.model.trust_remote_code,
num_labels=config.model.num_labels,
)
console.print(f" Architecture: {handle.architecture}")
console.print(f" Layers: {handle.num_layers} Heads: {handle.num_heads}")
console.print(f" Hidden: {handle.hidden_size} Params: {handle.summary()['total_params']:,}")
# --- 2. Load dataset ---
console.print(f"\n[bold cyan]Loading dataset:[/bold cyan] {config.dataset.name}")
ds_kwargs = {"path": config.dataset.name, "split": config.dataset.split}
if config.dataset.subset:
ds_kwargs["name"] = config.dataset.subset
dataset = load_dataset(**ds_kwargs)
console.print(f" Samples: {len(dataset)}")
# --- 3. Baseline evaluation ---
console.print("\n[bold green]Computing baseline metrics...[/bold green]")
evaluator = Evaluator(
handle=handle,
dataset=dataset,
metrics=config.metrics,
batch_size=config.batch_size,
max_length=config.max_length,
max_samples=config.dataset.max_samples,
text_column=config.dataset.text_column,
label_column=config.dataset.label_column,
)
baseline = evaluator.evaluate()
console.print(f" Baseline: {baseline}")
report = AblationReport(model_name=config.model.name)
report.add_baseline(baseline)
# --- 4. Run ablation strategies ---
for strat_cfg in config.strategies:
console.print(f"\n[bold magenta]Strategy:[/bold magenta] {strat_cfg.name}")
strategy = get_strategy(strat_cfg.name)
specs = strategy.enumerate(handle, **strat_cfg.params)
console.print(f" Ablation specs: {len(specs)}")
for spec in specs:
console.print(f" [dim]Ablating {spec.component}...[/dim]", end=" ")
# Apply ablation
strategy.apply(handle, spec)
# Evaluate
ablated_eval = Evaluator(
handle=handle,
dataset=dataset,
metrics=config.metrics,
batch_size=config.batch_size,
max_length=config.max_length,
max_samples=config.dataset.max_samples,
text_column=config.dataset.text_column,
label_column=config.dataset.label_column,
)
metrics = ablated_eval.evaluate()
console.print(f"{metrics}")
report.add_result(
AblationResult(
strategy=spec.strategy_name,
component=spec.component,
description=spec.description,
metrics=metrics,
metadata=spec.metadata,
)
)
# Restore model
handle.restore()
# --- 5. Save outputs ---
report.save_json(output_dir / "results.json")
report.save_csv(output_dir / "results.csv")
# Try to generate plots (may fail in headless environments)
try:
metric_name = config.metrics[0]
report.plot_impact(metric=metric_name, output_path=output_dir / "impact.png")
report.plot_heatmap(output_path=output_dir / "heatmap.png")
console.print(f"\n[bold]Plots saved to {output_dir}/[/bold]")
except Exception as e:
console.print(f"\n[yellow]Could not generate plots: {e}[/yellow]")
console.print(f"\n[bold green]Results saved to {output_dir}/[/bold green]")
report.print_summary()
return report