mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 06:35:59 +02:00
329 lines
11 KiB
Python
329 lines
11 KiB
Python
"""Interactive guided mode for non-technical users.
|
|
|
|
Run with: obliteratus interactive
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.table import Table
|
|
from rich.prompt import Prompt, IntPrompt, Confirm
|
|
|
|
from obliteratus.presets import (
|
|
ModelPreset,
|
|
get_presets_by_tier,
|
|
)
|
|
|
|
# Single rich Console instance shared by the functions in this module for all terminal output.
console = Console()
|
|
|
|
|
|
def _detect_compute_tier() -> str:
    """Guess the compute tier ("tiny", "small", "medium" or "large") for this machine.

    CUDA: the tier is derived from the total memory of device 0
    (>= 20 GiB -> "large", >= 8 GiB -> "medium", otherwise "small").
    MPS: the tier is derived from ``get_memory_info().total_gb``
    (>= 24 -> "medium", otherwise "small").
    Anything else, including a failed import of ``obliteratus.device`` or
    ``torch``, yields "tiny" (CPU only).
    """
    bytes_per_gib = 1024**3
    try:
        from obliteratus import device as hw

        if hw.is_cuda():
            import torch

            total_bytes = torch.cuda.get_device_properties(0).total_memory
            vram_gib = total_bytes / bytes_per_gib
            if vram_gib >= 20:
                return "large"
            return "medium" if vram_gib >= 8 else "small"

        if hw.is_mps():
            # Apple Silicon shares one memory pool between CPU and GPU, so the
            # estimate comes from the memory report rather than a VRAM figure.
            unified = hw.get_memory_info()
            return "medium" if unified.total_gb >= 24 else "small"
    except ImportError:
        # Missing optional dependency: treat the machine as CPU only.
        pass
    return "tiny"
|
|
|
|
|
|
def _pick_compute_tier() -> str:
    """Ask the user which hardware class they have and return the tier name.

    The auto-detected tier is offered as the default answer. The return value
    is one of "tiny", "small", "medium" or "large".
    """
    detected = _detect_compute_tier()

    menu = (
        "[bold]What hardware are you working with?[/bold]\n\n"
        " [cyan]1)[/cyan] [green]No GPU / basic laptop[/green] — CPU only, < 8GB RAM\n"
        " [cyan]2)[/cyan] [green]Basic GPU[/green] — 4-8 GB VRAM (GTX 1060, RTX 3050, etc.)\n"
        " [cyan]3)[/cyan] [green]Mid-range GPU[/green] — 8-16 GB VRAM (RTX 3060/4060/4070)\n"
        " [cyan]4)[/cyan] [green]High-end GPU[/green] — 24+ GB VRAM (RTX 3090/4090, A100)\n"
    )
    console.print()
    console.print(Panel(menu, title="Step 1: Hardware"))

    # Menu numbers follow the tier order, weakest hardware first.
    tier_names = ("tiny", "small", "medium", "large")
    tier_by_number = {str(n): name for n, name in enumerate(tier_names, start=1)}
    number_by_tier = {name: number for number, name in tier_by_number.items()}
    suggested = number_by_tier[detected]

    answer = Prompt.ask(
        f" Your choice (auto-detected: [bold]{suggested}[/bold])",
        choices=list(tier_by_number),
        default=suggested,
    )
    return tier_by_number[answer]
|
|
|
|
|
|
def _pick_model(tier: str) -> ModelPreset:
    """List the presets for ``tier`` (plus the tier below it) and return the user's pick.

    Entering 0 prompts for a custom HuggingFace model ID and returns an ad-hoc
    ``ModelPreset`` built from it. A number outside the listed range prints a
    warning and falls back to the first listed preset.
    """
    candidates = get_presets_by_tier(tier)

    # Presets from the next-lower tier are listed first as safe options.
    ladder = ["tiny", "small", "medium", "large"]
    rung = ladder.index(tier)
    if rung > 0:
        candidates = get_presets_by_tier(ladder[rung - 1]) + candidates

    console.print()
    table = Table(title="Recommended models for your hardware")
    column_specs = (
        ("#", {"style": "cyan", "justify": "right"}),
        ("Model", {"style": "green"}),
        ("Params", {"justify": "right"}),
        ("Tier", {"style": "yellow"}),
        ("Description", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for number, preset in enumerate(candidates, start=1):
        table.add_row(
            str(number),
            preset.name,
            preset.params,
            preset.tier.upper(),
            preset.description,
        )
    console.print(table)

    picked = IntPrompt.ask(
        "\n Pick a model number (or 0 to enter a custom HuggingFace model ID)",
        default=1,
    )

    if picked == 0:
        hf_model_id = Prompt.ask(" Enter HuggingFace model ID (e.g. 'gpt2')")
        # The "tiny" tier gets float32; every other tier gets float16.
        dtype = "float32" if tier == "tiny" else "float16"
        return ModelPreset(
            name=hf_model_id,
            hf_id=hf_model_id,
            description="Custom model",
            tier=tier,
            params="unknown",
            recommended_dtype=dtype,
        )

    if not 1 <= picked <= len(candidates):
        console.print("[red]Invalid choice, using first model.[/red]")
        return candidates[0]
    return candidates[picked - 1]
|
|
|
|
|
|
def _pick_study_preset():
    """Show the available ablation presets and return the one the user selects.

    Returns the chosen preset object, or ``None`` for custom mode. Any number
    outside ``1..len(presets)``, including the explicit "Custom" row, selects
    custom mode.
    """
    from obliteratus.study_presets import list_study_presets

    recipes = list_study_presets()

    console.print()
    table = Table(title="Ablation Presets — Pick a recipe or go custom")
    column_specs = (
        ("#", {"style": "cyan", "justify": "right"}),
        ("Name", {"style": "green"}),
        ("Strategies", {"style": "yellow"}),
        ("Samples", {"justify": "right"}),
        ("Description", {}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for number, recipe in enumerate(recipes, start=1):
        strategy_list = ", ".join(entry["name"] for entry in recipe.strategies)
        table.add_row(
            str(number),
            recipe.name,
            strategy_list,
            str(recipe.max_samples),
            recipe.description,
        )

    # The final row advertises custom mode under the next free number.
    custom_number = len(recipes) + 1
    table.add_row(
        str(custom_number),
        "Custom",
        "pick your own",
        "—",
        "Manually choose strategies and settings",
    )
    console.print(table)

    selected = IntPrompt.ask("\n Pick a preset number", default=1)
    if selected < 1 or selected > len(recipes):
        return None  # custom mode
    return recipes[selected - 1]
|
|
|
|
|
|
def _pick_strategies() -> list[dict]:
    """Ask which component families to ablate (custom mode).

    Returns a list of ``{"name": ..., "params": ...}`` dicts: a single entry
    for choices 1-4, or all four entries in menu order for choice 5.
    """
    menu = (
        "[bold]Which components do you want to test?[/bold]\n\n"
        " [cyan]1)[/cyan] [green]Layers[/green] — Remove entire transformer layers one by one\n"
        " [cyan]2)[/cyan] [green]Attention heads[/green] — Remove individual attention heads\n"
        " [cyan]3)[/cyan] [green]FFN blocks[/green] — Remove feed-forward networks\n"
        " [cyan]4)[/cyan] [green]Embeddings[/green] — Zero-out chunks of embedding dimensions\n"
        " [cyan]5)[/cyan] [green]All of the above[/green]\n"
    )
    console.print()
    console.print(Panel(menu, title="What to Ablate"))

    answer = Prompt.ask(" Your choice", choices=["1", "2", "3", "4", "5"], default="5")

    # Built fresh on every call so callers never share mutable dicts.
    catalogue = [
        {"name": "layer_removal", "params": {}},
        {"name": "head_pruning", "params": {}},
        {"name": "ffn_ablation", "params": {}},
        {"name": "embedding_ablation", "params": {"chunk_size": 48}},
    ]
    by_choice = {str(n): [entry] for n, entry in enumerate(catalogue, start=1)}
    by_choice["5"] = catalogue
    return by_choice[answer]
|
|
|
|
|
|
def _pick_sample_size() -> int:
    """Ask how thorough the evaluation should be (custom mode).

    Returns the sample count for the chosen depth: 25, 100 or 500.
    """
    menu = (
        "[bold]How thorough should the evaluation be?[/bold]\n\n"
        " [cyan]1)[/cyan] [green]Quick[/green] — 25 samples (fast, rough estimate)\n"
        " [cyan]2)[/cyan] [green]Standard[/green] — 100 samples (good balance)\n"
        " [cyan]3)[/cyan] [green]Thorough[/green] — 500 samples (slower, more accurate)\n"
    )
    console.print()
    console.print(Panel(menu, title="Evaluation Depth"))

    samples_by_choice = {"1": 25, "2": 100, "3": 500}
    answer = Prompt.ask(" Your choice", choices=list(samples_by_choice), default="2")
    return samples_by_choice[answer]
|
|
|
|
|
|
def run_interactive():
    """Walk the user through configuring and running an ablation study.

    Steps: pick a hardware tier, pick a model, pick a study preset (or custom
    strategies and sample size), resolve device/dtype/quantization, build a
    ``StudyConfig``, show a summary and ask for confirmation.

    Returns:
        ``None`` if the user declines at the confirmation prompt; otherwise
        whatever ``run_study`` returns, for both the quantized and the
        non-quantized path.
    """
    console.print()
    console.print(
        Panel.fit(
            "[bold white on blue] OBLITERATUS — Master Ablation Suite [/bold white on blue]\n\n"
            "This tool helps you understand which parts of an AI model\n"
            "are most important by systematically removing components\n"
            "and measuring the impact on performance.\n\n"
            "[dim]No coding required — just answer a few questions.[/dim]",
        )
    )

    # Step 1: Hardware
    tier = _pick_compute_tier()
    console.print(f"\n [bold]Selected tier:[/bold] {tier.upper()}")

    # Step 2: Model
    model_preset = _pick_model(tier)
    console.print(f"\n [bold]Selected model:[/bold] {model_preset.name} ({model_preset.hf_id})")

    # Step 3: Study preset OR custom strategies + sample size
    study_preset = _pick_study_preset()

    if study_preset is not None:
        console.print(f"\n [bold]Preset:[/bold] {study_preset.name}")
        strategies = study_preset.strategies
        max_samples = study_preset.max_samples
        batch_size = study_preset.batch_size
        max_length = study_preset.max_length
    else:
        strategies = _pick_strategies()
        max_samples = _pick_sample_size()
        # Custom mode: smaller batches on the two lowest tiers.
        batch_size = 4 if tier in ("tiny", "small") else 8
        max_length = 256

    strategy_names = [s["name"] for s in strategies]
    console.print(f" [bold]Strategies:[/bold] {', '.join(strategy_names)}")

    # Step 4: Determine device and dtype
    device = "cpu"
    dtype = model_preset.recommended_dtype
    quantization = None
    try:
        from obliteratus import device as _dev

        resolved = _dev.get_device()
        if resolved != "cpu":
            # "mps" is passed through as-is; any other accelerator uses "auto".
            device = resolved if resolved == "mps" else "auto"
    except ImportError:
        # Device helper unavailable: stay on the CPU default.
        pass

    # Quantization is only offered when the preset recommends it and we are not on CPU.
    if model_preset.recommended_quantization and device != "cpu":
        if Confirm.ask(
            f"\n Use {model_preset.recommended_quantization} quantization? (saves memory)",
            default=True,
        ):
            quantization = model_preset.recommended_quantization

    # Build config
    from obliteratus.config import StudyConfig, ModelConfig, DatasetConfig, StrategyConfig

    # NOTE(review): trust_remote_code=True is also applied to a user-typed custom
    # model ID. If ModelConfig forwards this flag to the Hugging Face loader, it
    # allows code shipped in the model repository to run locally — confirm this
    # is intended for arbitrary IDs.
    model_cfg = ModelConfig(
        name=model_preset.hf_id,
        task="causal_lm",
        dtype=dtype,
        device=device,
        trust_remote_code=True,
    )

    dataset_cfg = DatasetConfig(
        name="wikitext",
        subset="wikitext-2-raw-v1",
        split="test",
        text_column="text",
        max_samples=max_samples,
    )

    strategy_cfgs = [StrategyConfig(name=s["name"], params=s.get("params", {})) for s in strategies]

    config = StudyConfig(
        model=model_cfg,
        dataset=dataset_cfg,
        strategies=strategy_cfgs,
        metrics=["perplexity"],
        batch_size=batch_size,
        max_length=max_length,
        # Replace "/" in "org/model" IDs so the output directory is not nested.
        output_dir=f"results/{model_preset.hf_id.replace('/', '_')}",
    )

    # Confirmation
    # Same `is not None` test as the branch above, so the label always matches
    # the settings that were actually used.
    preset_label = f" (preset: {study_preset.name})" if study_preset is not None else " (custom)"
    console.print()
    console.print(Panel(
        f"[bold]Model:[/bold] {model_preset.name} ({model_preset.hf_id})\n"
        f"[bold]Device:[/bold] {device} ({dtype})"
        + (f" + {quantization}" if quantization else "")
        + f"\n[bold]Dataset:[/bold] wikitext-2 ({max_samples} samples)\n"
        f"[bold]Ablation:[/bold] {', '.join(strategy_names)}{preset_label}\n"
        f"[bold]Output:[/bold] {config.output_dir}/",
        title="Run Configuration",
    ))

    if not Confirm.ask("\n Ready to start?", default=True):
        console.print("[yellow]Cancelled.[/yellow]")
        return None

    # Quantized runs go through _run_quantized, which adjusts the model config
    # before running. Its result must be returned, otherwise callers get None
    # even though the study ran.
    if quantization:
        return _run_quantized(config, quantization)

    from obliteratus.runner import run_study

    return run_study(config)
|
|
|
|
|
|
def _run_quantized(config, quantization: str):
    """Run the study with the model loaded under ``quantization``.

    Mutates ``config.model`` in place (device forced to "auto", quantization
    set to the given value), prints a reminder to install bitsandbytes, and
    returns the result of ``run_study(config)``.
    """
    from obliteratus.runner import run_study

    console.print(f"\n[bold yellow]Note:[/bold yellow] Loading with {quantization} quantization...")
    console.print(" Make sure 'bitsandbytes' is installed: pip install bitsandbytes\n")

    # Quantized loading uses device "auto" (the original author notes that
    # bitsandbytes needs it).
    model_settings = config.model
    model_settings.device = "auto"
    model_settings.quantization = quantization
    return run_study(config)
|