mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 22:47:50 +02:00
129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
"""Main ablation runner — orchestrates the full pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
from datasets import load_dataset
|
|
from rich.console import Console
|
|
|
|
from obliteratus.config import StudyConfig
|
|
from obliteratus.evaluation.evaluator import Evaluator
|
|
from obliteratus.models.loader import load_model
|
|
from obliteratus.reporting.report import AblationReport, AblationResult
|
|
from obliteratus.strategies import get_strategy
|
|
|
|
console = Console()
|
|
|
|
|
|
def run_study(config: StudyConfig) -> AblationReport:
|
|
"""Execute a full ablation study from a StudyConfig.
|
|
|
|
Steps:
|
|
1. Load model from HuggingFace.
|
|
2. Load evaluation dataset.
|
|
3. Compute baseline metrics.
|
|
4. For each strategy, enumerate ablation specs and evaluate each.
|
|
5. Collect everything into an AblationReport.
|
|
"""
|
|
output_dir = Path(config.output_dir)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# --- 1. Load model ---
|
|
console.print(f"\n[bold cyan]Loading model:[/bold cyan] {config.model.name}")
|
|
handle = load_model(
|
|
model_name=config.model.name,
|
|
task=config.model.task,
|
|
device=config.model.device,
|
|
dtype=config.model.dtype,
|
|
trust_remote_code=config.model.trust_remote_code,
|
|
num_labels=config.model.num_labels,
|
|
)
|
|
console.print(f" Architecture: {handle.architecture}")
|
|
console.print(f" Layers: {handle.num_layers} Heads: {handle.num_heads}")
|
|
console.print(f" Hidden: {handle.hidden_size} Params: {handle.summary()['total_params']:,}")
|
|
|
|
# --- 2. Load dataset ---
|
|
console.print(f"\n[bold cyan]Loading dataset:[/bold cyan] {config.dataset.name}")
|
|
ds_kwargs = {"path": config.dataset.name, "split": config.dataset.split}
|
|
if config.dataset.subset:
|
|
ds_kwargs["name"] = config.dataset.subset
|
|
dataset = load_dataset(**ds_kwargs)
|
|
console.print(f" Samples: {len(dataset)}")
|
|
|
|
# --- 3. Baseline evaluation ---
|
|
console.print("\n[bold green]Computing baseline metrics...[/bold green]")
|
|
evaluator = Evaluator(
|
|
handle=handle,
|
|
dataset=dataset,
|
|
metrics=config.metrics,
|
|
batch_size=config.batch_size,
|
|
max_length=config.max_length,
|
|
max_samples=config.dataset.max_samples,
|
|
text_column=config.dataset.text_column,
|
|
label_column=config.dataset.label_column,
|
|
)
|
|
baseline = evaluator.evaluate()
|
|
console.print(f" Baseline: {baseline}")
|
|
|
|
report = AblationReport(model_name=config.model.name)
|
|
report.add_baseline(baseline)
|
|
|
|
# --- 4. Run ablation strategies ---
|
|
for strat_cfg in config.strategies:
|
|
console.print(f"\n[bold magenta]Strategy:[/bold magenta] {strat_cfg.name}")
|
|
strategy = get_strategy(strat_cfg.name)
|
|
specs = strategy.enumerate(handle, **strat_cfg.params)
|
|
console.print(f" Ablation specs: {len(specs)}")
|
|
|
|
for spec in specs:
|
|
console.print(f" [dim]Ablating {spec.component}...[/dim]", end=" ")
|
|
|
|
# Apply ablation
|
|
strategy.apply(handle, spec)
|
|
|
|
# Evaluate
|
|
ablated_eval = Evaluator(
|
|
handle=handle,
|
|
dataset=dataset,
|
|
metrics=config.metrics,
|
|
batch_size=config.batch_size,
|
|
max_length=config.max_length,
|
|
max_samples=config.dataset.max_samples,
|
|
text_column=config.dataset.text_column,
|
|
label_column=config.dataset.label_column,
|
|
)
|
|
metrics = ablated_eval.evaluate()
|
|
console.print(f"{metrics}")
|
|
|
|
report.add_result(
|
|
AblationResult(
|
|
strategy=spec.strategy_name,
|
|
component=spec.component,
|
|
description=spec.description,
|
|
metrics=metrics,
|
|
metadata=spec.metadata,
|
|
)
|
|
)
|
|
|
|
# Restore model
|
|
handle.restore()
|
|
|
|
# --- 5. Save outputs ---
|
|
report.save_json(output_dir / "results.json")
|
|
report.save_csv(output_dir / "results.csv")
|
|
|
|
# Try to generate plots (may fail in headless environments)
|
|
try:
|
|
metric_name = config.metrics[0]
|
|
report.plot_impact(metric=metric_name, output_path=output_dir / "impact.png")
|
|
report.plot_heatmap(output_path=output_dir / "heatmap.png")
|
|
console.print(f"\n[bold]Plots saved to {output_dir}/[/bold]")
|
|
except Exception as e:
|
|
console.print(f"\n[yellow]Could not generate plots: {e}[/yellow]")
|
|
|
|
console.print(f"\n[bold green]Results saved to {output_dir}/[/bold green]")
|
|
report.print_summary()
|
|
|
|
return report
|