"""Interactive guided mode for non-technical users.
Run with: obliteratus interactive
"""

from __future__ import annotations

from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt, IntPrompt, Confirm

from obliteratus.presets import (
    ModelPreset,
    get_presets_by_tier,
)

console = Console()


def _detect_compute_tier() -> str:
    """Auto-detect the best compute tier based on available hardware."""
    try:
        from obliteratus import device as dev

        if dev.is_cuda():
            import torch

            vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
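            # Heuristic buckets: the 20 GB cutoff likely keeps 24 GB cards,
            # which report slightly under 24 GiB, in the "large" tier.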
            if vram_gb >= 20:
                return "large"
            elif vram_gb >= 8:
                return "medium"
            else:
                return "small"
        elif dev.is_mps():
            # Apple Silicon with unified memory — estimate from system RAM
            mem = dev.get_memory_info()
            if mem.total_gb >= 24:
                return "medium"  # M1 Pro/Max/Ultra, M2 Pro/Max/Ultra, M3 Pro/Max
            else:
                return "small"  # M1/M2/M3 base (8-16 GB)
    except ImportError:
        pass
    return "tiny"  # CPU only


def _pick_compute_tier() -> str:
    """Let the user choose their compute tier, with the auto-detected tier as the default."""
    detected = _detect_compute_tier()
    console.print()
    console.print(
        Panel(
            "[bold]What hardware are you working with?[/bold]\n\n"
            " [cyan]1)[/cyan] [green]No GPU / basic laptop[/green] — CPU only, < 8 GB RAM\n"
            " [cyan]2)[/cyan] [green]Basic GPU[/green] — 4-8 GB VRAM (GTX 1060, RTX 3050, etc.)\n"
            " [cyan]3)[/cyan] [green]Mid-range GPU[/green] — 8-16 GB VRAM (RTX 3060/4060/4070)\n"
            " [cyan]4)[/cyan] [green]High-end GPU[/green] — 24+ GB VRAM (RTX 3090/4090, A100)\n",
            title="Step 1: Hardware",
        )
    )
    tier_map = {"1": "tiny", "2": "small", "3": "medium", "4": "large"}
    detected_num = {"tiny": "1", "small": "2", "medium": "3", "large": "4"}[detected]
    choice = Prompt.ask(
        f" Your choice (auto-detected: [bold]{detected_num}[/bold])",
        choices=["1", "2", "3", "4"],
        default=detected_num,
    )
    return tier_map[choice]


def _pick_model(tier: str) -> ModelPreset:
    """Show models for the chosen tier and let the user pick."""
    presets = get_presets_by_tier(tier)
    # Also show one tier below as safe options
    tier_order = ["tiny", "small", "medium", "large"]
    idx = tier_order.index(tier)
    if idx > 0:
        presets = get_presets_by_tier(tier_order[idx - 1]) + presets
    console.print()
    table = Table(title="Recommended models for your hardware")
    table.add_column("#", style="cyan", justify="right")
    table.add_column("Model", style="green")
    table.add_column("Params", justify="right")
    table.add_column("Tier", style="yellow")
    table.add_column("Description")
    for i, p in enumerate(presets, 1):
        table.add_row(str(i), p.name, p.params, p.tier.upper(), p.description)
    console.print(table)
    choice = IntPrompt.ask(
        "\n Pick a model number (or 0 to enter a custom HuggingFace model ID)",
        default=1,
    )
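    # 0 is the sentinel for entering a custom HuggingFace model ID (see the prompt above).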
    if choice == 0:
        custom_id = Prompt.ask(" Enter HuggingFace model ID (e.g. 'gpt2')")
        return ModelPreset(
            name=custom_id,
            hf_id=custom_id,
            description="Custom model",
            tier=tier,
            params="unknown",
            recommended_dtype="float16" if tier != "tiny" else "float32",
        )
    if 1 <= choice <= len(presets):
        return presets[choice - 1]
    console.print("[red]Invalid choice, using first model.[/red]")
    return presets[0]


def _pick_study_preset():
    """Let the user pick an ablation preset or go custom.

    Returns a StudyPreset if chosen, or None for custom mode.
    """
    from obliteratus.study_presets import list_study_presets

    presets = list_study_presets()
    console.print()
    table = Table(title="Ablation Presets — Pick a recipe or go custom")
    table.add_column("#", style="cyan", justify="right")
    table.add_column("Name", style="green")
    table.add_column("Strategies", style="yellow")
    table.add_column("Samples", justify="right")
    table.add_column("Description")
    for i, p in enumerate(presets, 1):
        strats = ", ".join(s["name"] for s in p.strategies)
        table.add_row(str(i), p.name, strats, str(p.max_samples), p.description)
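    # Append a synthetic "Custom" row as the final option; picking it (or any
    # out-of-range number) falls through to the `return None` below.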
    table.add_row(
        str(len(presets) + 1), "Custom", "pick your own", "",
        "Manually choose strategies and settings",
    )
    console.print(table)
    choice = IntPrompt.ask("\n Pick a preset number", default=1)
    if 1 <= choice <= len(presets):
        return presets[choice - 1]
    return None  # custom mode


def _pick_strategies() -> list[dict]:
    """Let the user choose which ablation strategies to run (custom mode)."""
    console.print()
    console.print(
        Panel(
            "[bold]Which components do you want to test?[/bold]\n\n"
            " [cyan]1)[/cyan] [green]Layers[/green] — Remove entire transformer layers one by one\n"
            " [cyan]2)[/cyan] [green]Attention heads[/green] — Remove individual attention heads\n"
            " [cyan]3)[/cyan] [green]FFN blocks[/green] — Remove feed-forward networks\n"
            " [cyan]4)[/cyan] [green]Embeddings[/green] — Zero-out chunks of embedding dimensions\n"
            " [cyan]5)[/cyan] [green]All of the above[/green]\n",
            title="What to Ablate",
        )
    )
    choice = Prompt.ask(" Your choice", choices=["1", "2", "3", "4", "5"], default="5")
    mapping = {
        "1": [{"name": "layer_removal", "params": {}}],
        "2": [{"name": "head_pruning", "params": {}}],
        "3": [{"name": "ffn_ablation", "params": {}}],
        "4": [{"name": "embedding_ablation", "params": {"chunk_size": 48}}],
        "5": [
            {"name": "layer_removal", "params": {}},
            {"name": "head_pruning", "params": {}},
            {"name": "ffn_ablation", "params": {}},
            {"name": "embedding_ablation", "params": {"chunk_size": 48}},
        ],
    }
    return mapping[choice]


def _pick_sample_size() -> int:
    """Let the user pick how many samples to evaluate on (custom mode)."""
    console.print()
    console.print(
        Panel(
            "[bold]How thorough should the evaluation be?[/bold]\n\n"
            " [cyan]1)[/cyan] [green]Quick[/green] — 25 samples (fast, rough estimate)\n"
            " [cyan]2)[/cyan] [green]Standard[/green] — 100 samples (good balance)\n"
            " [cyan]3)[/cyan] [green]Thorough[/green] — 500 samples (slower, more accurate)\n",
            title="Evaluation Depth",
        )
    )
    choice = Prompt.ask(" Your choice", choices=["1", "2", "3"], default="2")
    return {"1": 25, "2": 100, "3": 500}[choice]


def run_interactive():
    """Main interactive flow — walks the user through setting up and running an ablation."""
    console.print()
    console.print(
        Panel.fit(
            "[bold white on blue] OBLITERATUS — Master Ablation Suite [/bold white on blue]\n\n"
            "This tool helps you understand which parts of an AI model\n"
            "are most important by systematically removing components\n"
            "and measuring the impact on performance.\n\n"
            "[dim]No coding required — just answer a few questions.[/dim]",
        )
    )

    # Step 1: Hardware
    tier = _pick_compute_tier()
    console.print(f"\n [bold]Selected tier:[/bold] {tier.upper()}")

    # Step 2: Model
    model_preset = _pick_model(tier)
    console.print(f"\n [bold]Selected model:[/bold] {model_preset.name} ({model_preset.hf_id})")

    # Step 3: Study preset OR custom strategies + sample size
    study_preset = _pick_study_preset()
    if study_preset is not None:
        console.print(f"\n [bold]Preset:[/bold] {study_preset.name}")
        strategies = study_preset.strategies
        max_samples = study_preset.max_samples
        batch_size = study_preset.batch_size
        max_length = study_preset.max_length
    else:
        strategies = _pick_strategies()
        max_samples = _pick_sample_size()
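        # Conservative defaults for custom mode: smaller batches on memory-constrained tiers.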
        batch_size = 4 if tier in ("tiny", "small") else 8
        max_length = 256
    strategy_names = [s["name"] for s in strategies]
    console.print(f" [bold]Strategies:[/bold] {', '.join(strategy_names)}")

    # Step 4: Determine device and dtype
    device = "cpu"
    dtype = model_preset.recommended_dtype
    quantization = None
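    # Resolve the device when the helper is importable: MPS is addressed by
    # name, any other accelerator goes through "auto" placement, and we
    # otherwise stay on CPU.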
    try:
        from obliteratus import device as _dev

        resolved = _dev.get_device()
        if resolved != "cpu":
            device = resolved if resolved == "mps" else "auto"
    except ImportError:
        pass
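    # Quantization is offered only off-CPU; the loader path assumes a
    # bitsandbytes-style mode (see _run_quantized below), which is typically
    # GPU-only.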
    if model_preset.recommended_quantization and device != "cpu":
        if Confirm.ask(
            f"\n Use {model_preset.recommended_quantization} quantization? (saves memory)",
            default=True,
        ):
            quantization = model_preset.recommended_quantization

    # Build config
    from obliteratus.config import StudyConfig, ModelConfig, DatasetConfig, StrategyConfig

    model_cfg = ModelConfig(
        name=model_preset.hf_id,
        task="causal_lm",
        dtype=dtype,
        device=device,
        trust_remote_code=True,
    )
    dataset_cfg = DatasetConfig(
        name="wikitext",
        subset="wikitext-2-raw-v1",
        split="test",
        text_column="text",
        max_samples=max_samples,
    )
    strategy_cfgs = [StrategyConfig(name=s["name"], params=s.get("params", {})) for s in strategies]
    config = StudyConfig(
        model=model_cfg,
        dataset=dataset_cfg,
        strategies=strategy_cfgs,
        metrics=["perplexity"],
        batch_size=batch_size,
        max_length=max_length,
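        # Flatten "org/model" IDs so the run directory has no nested path components.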
output_dir=f"results/{model_preset.hf_id.replace('/', '_')}",
)
# Confirmation
preset_label = f" (preset: {study_preset.name})" if study_preset else " (custom)"
console.print()
console.print(Panel(
f"[bold]Model:[/bold] {model_preset.name} ({model_preset.hf_id})\n"
f"[bold]Device:[/bold] {device} ({dtype})"
+ (f" + {quantization}" if quantization else "")
+ f"\n[bold]Dataset:[/bold] wikitext-2 ({max_samples} samples)\n"
f"[bold]Ablation:[/bold] {', '.join(strategy_names)}{preset_label}\n"
f"[bold]Output:[/bold] {config.output_dir}/",
title="Run Configuration",
))
if not Confirm.ask("\n Ready to start?", default=True):
console.print("[yellow]Cancelled.[/yellow]")
return None

    # Handle quantization by modifying the loader
    if quantization:
        return _run_quantized(config, quantization)

    from obliteratus.runner import run_study

    return run_study(config)


def _run_quantized(config, quantization: str):
    """Run ablation with quantized model loading."""
    from obliteratus.runner import run_study

    # Patch the model loading to use bitsandbytes quantization
    console.print(f"\n[bold yellow]Note:[/bold yellow] Loading with {quantization} quantization...")
    console.print(" Make sure 'bitsandbytes' is installed: pip install bitsandbytes\n")
    # For quantized models, we modify the config device to auto (needed for bitsandbytes)
    config.model.device = "auto"
    config.model.quantization = quantization
    return run_study(config)