# Source: OBLITERATUS/obliteratus/cli.py
# Snapshot: 2026-03-04 13:09:57 -08:00 (508 lines, 18 KiB, Python)

"""CLI entry point for Obliteratus — Master Ablation Suite."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from rich.console import Console
# Module-wide Rich console; every command handler in this file prints through it.
console = Console()
# Startup banner written in Rich markup. main() prints it before parsing arguments.
# Kept as a raw string so the block-drawing art is emitted verbatim.
_BANNER = r"""
[bold red]
░▒█▀▀▀█ ░▒█▀▀▄ ░▒█░░░ ▀█▀ ▀▀█▀▀ ░▒█▀▀▀ ░▒█▀▀█ ▒█▀▀█ ▀▀█▀▀ ░▒█░░▒█ ░▒█▀▀▀█
░▒█░░▒█ ░▒█▀▀▄ ░▒█░░░ ░█░ ░░█░░ ░▒█▀▀▀ ░▒█▄▄▀ ▒█▄▄█ ░░█░░ ░▒█░░▒█ ░░▀▀▀▄▄
░▒█▄▄▄█ ░▒█▄▄▀ ░▒█▄▄█ ▄█▄ ░░▀░░ ░▒█▄▄▄ ░▒█░▒█ ▒█░▒█ ░░▀░░ ░░▒█▄▄█ ░▒█▄▄▄█
[/bold red]
[dim] ════════════════════════════════════════════════════════════════════[/dim]
[bold white] MASTER ABLATION SUITE[/bold white] [dim]//[/dim] [bold red]Break the chains. Free the mind.[/bold red]
[dim] ════════════════════════════════════════════════════════════════════[/dim]
"""
def main(argv: list[str] | None = None):
    """Parse command-line arguments and run the selected sub-command.

    The banner is printed before any parsing happens, so it also appears when
    argparse exits early (``--help`` or a usage error).

    Args:
        argv: Argument list to parse. ``None`` lets argparse read the
            process arguments itself.
    """
    console.print(_BANNER)

    root = argparse.ArgumentParser(
        prog="obliteratus",
        description="Master Ablation Suite for HuggingFace transformers",
    )
    commands = root.add_subparsers(dest="command", required=True)

    # run: execute a study described by a YAML config
    run_cmd = commands.add_parser("run", help="Run an ablation from a YAML config")
    run_cmd.add_argument("config", type=str, help="Path to YAML config file")
    run_cmd.add_argument("--output-dir", type=str, help="Override output dir")
    run_cmd.add_argument(
        "--preset",
        type=str,
        help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
    )

    # info: load a model and print its summary
    info_cmd = commands.add_parser("info", help="Print model architecture info")
    info_cmd.add_argument("model", type=str, help="HuggingFace model name/path")
    info_cmd.add_argument(
        "--task",
        type=str,
        default="causal_lm",
        choices=["causal_lm", "classification"],
    )
    info_cmd.add_argument("--device", type=str, default="cpu")
    info_cmd.add_argument("--dtype", type=str, default="float32")

    # interactive: no options of its own
    commands.add_parser(
        "interactive",
        help="Guided setup — pick hardware, model, and preset interactively",
    )

    # models: curated model table, optionally filtered by tier
    models_cmd = commands.add_parser("models", help="Browse curated models by compute tier")
    models_cmd.add_argument(
        "--tier",
        type=str,
        choices=["tiny", "small", "medium", "large", "frontier"],
        help="Filter by compute tier",
    )

    # presets / strategies: listing commands without options
    commands.add_parser("presets", help="Browse ablation presets (quick, full, jailbreak, etc.)")
    commands.add_parser("strategies", help="List available ablation strategies")

    # ui: local Gradio front end
    ui_cmd = commands.add_parser(
        "ui",
        help="Launch the Gradio web UI locally (same UI as the HuggingFace Space)",
    )
    ui_cmd.add_argument("--port", type=int, default=7860, help="Server port (default: 7860)")
    ui_cmd.add_argument("--host", type=str, default="0.0.0.0", help="Server host (default: 0.0.0.0)")
    ui_cmd.add_argument("--share", action="store_true", help="Create a public Gradio share link")
    ui_cmd.add_argument("--no-browser", action="store_true", help="Don't auto-open browser on launch")
    ui_cmd.add_argument("--auth", type=str, help="Basic auth as user:pass")
    ui_cmd.add_argument("--quiet", action="store_true", help="Suppress the startup banner")

    def _register_obliterate_options(sub):
        """Attach the option set shared by ``obliterate`` and its alias."""
        sub.add_argument("model", type=str, help="HuggingFace model name/path")
        sub.add_argument("--output-dir", type=str, help="Where to save the obliterated model")
        sub.add_argument("--device", type=str, default="auto")
        sub.add_argument("--dtype", type=str, default="float16")
        sub.add_argument(
            "--method",
            type=str,
            default="advanced",
            choices=[
                "basic", "advanced", "aggressive", "spectral_cascade",
                "informed", "surgical", "optimized", "inverted", "nuclear",
            ],
            help="Liberation method (default: advanced)",
        )
        sub.add_argument(
            "--n-directions", type=int,
            help="Override: number of SVD directions to extract",
        )
        sub.add_argument(
            "--regularization", type=float,
            help="Override: fraction to preserve (0.0-1.0)",
        )
        sub.add_argument(
            "--refinement-passes", type=int,
            help="Override: number of iterative passes",
        )
        sub.add_argument(
            "--quantization",
            type=str,
            choices=["4bit", "8bit"],
            help="Load model with quantization (4bit or 8bit). Requires bitsandbytes.",
        )
        sub.add_argument(
            "--large-model",
            action="store_true",
            help="Enable conservative defaults for 120B+ models (fewer directions, 1 pass, lower SAE expansion).",
        )
        sub.add_argument(
            "--verify-sample-size",
            type=int,
            help="Number of harmful prompts to test for refusal rate (default: 30). "
            "Increase for tighter confidence intervals (e.g. 100 for ~1%% resolution).",
        )
        sub.add_argument(
            "--contribute",
            action="store_true",
            help="Save a community contribution record after the run completes.",
        )
        sub.add_argument(
            "--contribute-notes",
            type=str,
            default="",
            help="Optional notes to include with the community contribution.",
        )

    # obliterate is the primary name; abliterate is kept as a hidden
    # backward-compatible alias with the same options.
    for cmd_name, cmd_help in (
        ("obliterate", "One-click: remove refusal directions from a model (SOTA multi-technique)"),
        ("abliterate", argparse.SUPPRESS),
    ):
        _register_obliterate_options(commands.add_parser(cmd_name, help=cmd_help))

    # report: rebuild a report from a saved results file
    report_cmd = commands.add_parser("report", help="Regenerate report from saved results")
    report_cmd.add_argument("results_json", type=str, help="Path to results.json")
    report_cmd.add_argument("--output-dir", type=str)

    # aggregate: summarise community contribution files
    aggregate_cmd = commands.add_parser("aggregate", help="Aggregate community contribution results")
    aggregate_cmd.add_argument(
        "--dir",
        type=str,
        default="community_results",
        help="Directory containing contribution JSON files",
    )

    opts = root.parse_args(argv)

    # Command name -> zero-argument callable. Lambdas defer the handler lookup
    # until the chosen command is actually invoked.
    handlers = {
        "run": lambda: _cmd_run(opts),
        "interactive": lambda: _cmd_interactive(),
        "models": lambda: _cmd_models(opts),
        "presets": lambda: _cmd_presets(),
        "info": lambda: _cmd_info(opts),
        "strategies": lambda: _cmd_strategies(),
        "report": lambda: _cmd_report(opts),
        "aggregate": lambda: _cmd_aggregate(opts),
        "ui": lambda: _cmd_ui(opts),
        "obliterate": lambda: _cmd_abliterate(opts),
        "abliterate": lambda: _cmd_abliterate(opts),
    }
    handler = handlers.get(opts.command)
    if handler is not None:
        handler()
def _cmd_ui(args):
    """Launch the local web UI with the options parsed for the ``ui`` command.

    Args:
        args: Parsed namespace providing ``host``, ``port``, ``share``,
            ``no_browser``, ``auth`` and ``quiet``.

    Raises:
        SystemExit: If ``--auth`` was supplied without a ``:`` separating the
            user name from the password.
    """
    auth = None
    if args.auth:
        # Split on the first ':' only, so the password may itself contain ':'.
        user, separator, password = args.auth.partition(":")
        if not separator:
            # Without this check a one-element tuple would be handed to the
            # UI as credentials; fail here with a clear message instead.
            raise SystemExit("error: --auth must be given as user:pass")
        auth = (user, password)

    # Imported after validation so a malformed --auth fails before the UI
    # module is loaded.
    from obliteratus.local_ui import launch_local_ui

    launch_local_ui(
        host=args.host,
        port=args.port,
        share=args.share,
        open_browser=not args.no_browser,
        auth=auth,
        quiet=args.quiet,
    )
def _cmd_interactive():
    """Hand control to the guided interactive setup."""
    from obliteratus.interactive import run_interactive as start_guided_setup

    start_guided_setup()
def _cmd_models(args):
    """Print the curated model library as a table.

    Args:
        args: Parsed namespace; a truthy ``tier`` restricts the listing to
            that tier, otherwise every preset is shown.
    """
    from rich.table import Table

    from obliteratus.presets import get_presets_by_tier, list_all_presets

    if args.tier:
        entries = get_presets_by_tier(args.tier)
    else:
        entries = list_all_presets()

    library = Table(title="Model Library — Curated Targets")
    # (header, add_column keyword options) in display order.
    column_specs = (
        ("Model", {"style": "green"}),
        ("HuggingFace ID", {"style": "cyan"}),
        ("Params", {"justify": "right"}),
        ("Tier", {"style": "yellow"}),
        ("Dtype", {}),
        ("Quant", {}),
        ("Description", {}),
    )
    for heading, options in column_specs:
        library.add_column(heading, **options)

    for entry in entries:
        library.add_row(
            entry.name,
            entry.hf_id,
            entry.params,
            entry.tier.upper(),
            entry.recommended_dtype,
            # A missing quantization recommendation is shown as an empty cell.
            entry.recommended_quantization or "",
            entry.description,
        )

    console.print(library)
    console.print(
        "\n[dim]Tiers: TINY = CPU/laptop | SMALL = 4-8GB | "
        "MEDIUM = 8-16GB | LARGE = 24GB+ | FRONTIER = multi-GPU/cloud[/dim]"
    )
def _cmd_presets():
    """Print a table of the available study presets followed by a usage hint."""
    from rich.table import Table

    from obliteratus.study_presets import list_study_presets

    available = list_study_presets()

    overview = Table(title="Ablation Presets")
    overview.add_column("Key", style="cyan", min_width=12)
    overview.add_column("Name", style="green")
    overview.add_column("Strategies", style="yellow")
    overview.add_column("Samples", justify="right")
    overview.add_column("Description", max_width=55)

    for preset in available:
        # Each strategy entry is a mapping; only its "name" is displayed.
        strategy_names = [entry["name"] for entry in preset.strategies]
        overview.add_row(
            preset.key,
            preset.name,
            ", ".join(strategy_names),
            str(preset.max_samples),
            preset.description,
        )

    console.print(overview)
    console.print(
        "\n[dim]Usage: obliteratus run config.yaml --preset quick\n"
        " or: set preset: quick in your YAML file[/dim]"
    )
def _cmd_run(args):
    """Load a study configuration from YAML and run it.

    When ``--preset`` is given, the YAML is read as a raw mapping, the preset
    key is injected, and the config is built from that mapping. Otherwise the
    config is loaded directly from the YAML file. In both cases the config is
    built exactly once.

    Args:
        args: Parsed namespace providing ``config`` (YAML path), ``preset``
            and ``output_dir``.

    Raises:
        TypeError: If ``--preset`` is used and the YAML top level is not a
            mapping.
    """
    from obliteratus.config import StudyConfig
    from obliteratus.runner import run_study

    if args.preset:
        import yaml

        raw = yaml.safe_load(Path(args.config).read_text())
        if raw is None:
            # An empty YAML document parses to None; treat it as an empty
            # mapping so the preset can still be applied.
            raw = {}
        if not isinstance(raw, dict):
            raise TypeError(
                f"{args.config}: expected a YAML mapping at the top level, "
                f"got {type(raw).__name__}"
            )
        # Inject the preset before building the config so from_dict sees it.
        raw["preset"] = args.preset
        config = StudyConfig.from_dict(raw)
    else:
        config = StudyConfig.from_yaml(args.config)

    if args.output_dir:
        config.output_dir = args.output_dir
    run_study(config)
def _cmd_info(args):
    """Load a model and print each entry of its summary.

    Args:
        args: Parsed namespace providing ``model``, ``task``, ``device`` and
            ``dtype``.
    """
    from obliteratus.models.loader import load_model

    console.print(f"[bold cyan]Loading model:[/bold cyan] {args.model}")
    loaded = load_model(
        model_name=args.model,
        task=args.task,
        device=args.device,
        dtype=args.dtype,
    )

    for field, value in loaded.summary().items():
        # Integers above 1000 get thousands separators for readability.
        use_separators = isinstance(value, int) and value > 1000
        shown = f"{value:,}" if use_separators else f"{value}"
        console.print(f" {field}: {shown}")
def _cmd_strategies():
    """List registered strategies by name with the first line of each docstring."""
    from obliteratus.strategies import STRATEGY_REGISTRY

    console.print("[bold]Available ablation strategies:[/bold]\n")
    for name in sorted(STRATEGY_REGISTRY):
        docstring = STRATEGY_REGISTRY[name].__doc__ or ""
        # Only the first line of the stripped docstring is shown.
        summary, _, _ = docstring.strip().partition("\n")
        console.print(f" [cyan]{name}[/cyan] — {summary}")
def _cmd_report(args):
    """Rebuild a report from a saved results file and regenerate its plots.

    The summary is always printed. Plotting is best-effort: a failure is
    reported as a warning and does not abort the command.

    Args:
        args: Parsed namespace providing ``results_json`` (path to the saved
            results) and ``output_dir`` (optional; defaults to the directory
            containing the results file).
    """
    from obliteratus.reporting.report import AblationReport, AblationResult

    results_path = Path(args.results_json)
    data = json.loads(results_path.read_text())

    report = AblationReport(model_name=data["model_name"])
    report.add_baseline(data["baseline_metrics"])
    for entry in data["results"]:
        report.add_result(
            AblationResult(
                strategy=entry["strategy"],
                component=entry["component"],
                description=entry["description"],
                metrics=entry["metrics"],
                metadata=entry.get("metadata"),
            )
        )
    report.print_summary()

    output_dir = Path(args.output_dir) if args.output_dir else results_path.parent

    # The impact plot uses the first baseline metric. With no baseline metrics
    # there is nothing to plot, so warn instead of failing with an IndexError.
    metric_name = next(iter(data["baseline_metrics"]), None)
    if metric_name is None:
        console.print(
            "[yellow]Could not generate plots: no baseline metrics in results[/yellow]"
        )
        return

    try:
        # An --output-dir that does not exist yet is created rather than
        # making the plot writers fail.
        output_dir.mkdir(parents=True, exist_ok=True)
        report.plot_impact(metric=metric_name, output_path=output_dir / "impact.png")
        report.plot_heatmap(output_path=output_dir / "heatmap.png")
        console.print(f"\nPlots saved to {output_dir}/")
    except Exception as e:
        # Deliberately broad: plotting is optional and must not fail the command.
        console.print(f"[yellow]Could not generate plots: {e}[/yellow]")
def _cmd_aggregate(args):
    """Print a per-model, per-method summary of community contributions.

    Args:
        args: Parsed namespace providing ``dir``, the directory the
            contribution records are loaded from.
    """
    from obliteratus.community import aggregate_results, load_contributions

    source_dir = args.dir
    contributions = load_contributions(source_dir)
    if not contributions:
        console.print(f"[yellow]No contributions found in {source_dir}[/yellow]")
        return

    summary = aggregate_results(contributions)

    from rich.table import Table

    def _cell(value, spec):
        """Format floats with *spec*; anything else (e.g. "N/A") via str()."""
        if isinstance(value, float):
            return format(value, spec)
        return str(value)

    results = Table(title="Aggregated Community Results")
    results.add_column("Model", style="green")
    results.add_column("Method", style="cyan")
    for heading in ("Runs", "Mean Refusal", "Mean Perplexity"):
        results.add_column(heading, justify="right")

    for model_name, per_method in sorted(summary.items()):
        # Show only the part after the last "/" of the model identifier.
        short_name = model_name.rsplit("/", 1)[-1]
        for method_name, stats in sorted(per_method.items()):
            mean_refusal = stats.get("refusal_rate", {}).get("mean", "N/A")
            mean_ppl = stats.get("perplexity", {}).get("mean", "N/A")
            results.add_row(
                short_name,
                method_name,
                str(stats["n_runs"]),
                _cell(mean_refusal, ".4f"),
                _cell(mean_ppl, ".2f"),
            )

    console.print(results)
def _cmd_abliterate(args):
    """Run the abliteration pipeline with a live progress display.

    Builds an ``AbliterationPipeline`` from the parsed options, shows per-stage
    status and the most recent log lines in a Rich ``Live`` panel while it
    runs, then optionally sends telemetry and saves a community contribution,
    and finally prints a completion panel.

    Args:
        args: Parsed namespace from the ``obliterate`` / ``abliterate``
            sub-command.

    Raises:
        Exception: Whatever ``pipeline.run()`` raises is re-raised after the
            error has been appended to the on-screen log.
    """
    from rich.console import Group
    from rich.live import Live
    from rich.markup import escape
    from rich.panel import Panel
    from rich.table import Table
    from rich.text import Text

    from obliteratus.abliterate import METHODS, STAGES, AbliterationPipeline

    model_name = args.model
    output_dir = args.output_dir or f"abliterated/{model_name.replace('/', '_')}"
    method = args.method
    method_label = METHODS.get(method, {}).get("label", method)
    # Derived from STAGES so the "[i/N]" counter stays right if stages change.
    total_stages = len(STAGES)

    # Stage state tracking, updated by the pipeline callbacks below.
    stage_status = {s.key: "waiting" for s in STAGES}
    stage_msgs = {s.key: "" for s in STAGES}
    log_lines: list[str] = []

    # Set once the Live display is active; the callbacks may fire before that.
    live = None

    def make_display():
        """Build the panel showing the header, stage table and recent log."""
        table = Table(show_header=False, expand=True, border_style="green")
        table.add_column("", width=6)
        table.add_column("Stage", min_width=10)
        table.add_column("Status", min_width=50)
        for position, s in enumerate(STAGES, start=1):
            st = stage_status[s.key]
            if st == "done":
                icon = "[bold green]✓[/]"
                bar = "[green]" + "█" * 20 + "[/]"
            elif st == "running":
                icon = "[bold yellow]⚡[/]"
                bar = "[yellow]" + "█" * 10 + "░" * 10 + "[/]"
            else:
                icon = "[dim]○[/]"
                bar = "[dim]" + "░" * 20 + "[/]"
            msg = stage_msgs.get(s.key, "")
            table.add_row(
                f"[cyan][{position}/{total_stages}][/]",
                f"{icon} [bold]{s.name}[/]",
                f"{bar} {msg}",
            )

        # The trailing newline leaves a blank line between header and table.
        header = Text.from_markup(
            f"[bold green]OBLITERATUS — ABLITERATION PIPELINE[/]\n"
            f"[dim]Target:[/] [cyan]{model_name}[/] → [cyan]{output_dir}[/]\n"
            f"[dim]Method:[/] [magenta]{method_label}[/]\n"
        )

        # Last 12 log lines
        recent = log_lines[-12:] if log_lines else ["Initializing..."]
        log_text = "\n".join(f"[dim]>[/] {line}" for line in recent)
        log_block = Text.from_markup(f"\n[dim]─── LOG ───[/]\n{log_text}")

        # Compose real renderables. Interpolating a Table into an f-string
        # would embed its repr() rather than the rendered table.
        return Panel(
            Group(header, table, log_block),
            border_style="green",
            title="[bold green]⚗ ABLITERATE ⚗[/]",
        )

    def on_stage(result):
        """Record a stage update and refresh the display if it is active."""
        stage_status[result.stage] = result.status
        stage_msgs[result.stage] = result.message
        if live is not None:
            live.update(make_display())

    def on_log(msg):
        """Append a log line and refresh the display if it is active."""
        log_lines.append(msg)
        if live is not None:
            live.update(make_display())

    pipeline = AbliterationPipeline(
        model_name=model_name,
        output_dir=output_dir,
        device=args.device,
        dtype=args.dtype,
        method=method,
        n_directions=args.n_directions,
        regularization=args.regularization,
        refinement_passes=args.refinement_passes,
        quantization=args.quantization,
        large_model_mode=getattr(args, "large_model", False),
        verify_sample_size=getattr(args, "verify_sample_size", None),
        on_stage=on_stage,
        on_log=on_log,
    )

    with Live(make_display(), console=console, refresh_per_second=4) as live_ctx:
        live = live_ctx
        try:
            result_path = pipeline.run()
            live.update(make_display())
        except Exception as e:
            # Escape the message so brackets in it are not parsed as markup.
            log_lines.append(f"[red]ERROR: {escape(str(e))}[/]")
            live.update(make_display())
            raise

    # ── Telemetry: send pipeline report to community leaderboard ──
    try:
        from obliteratus.telemetry import maybe_send_pipeline_report

        maybe_send_pipeline_report(pipeline)
    except Exception:
        pass  # Telemetry is best-effort

    # ── Community contribution (--contribute flag) ──
    contrib_path = None
    if getattr(args, "contribute", False):
        try:
            from obliteratus.community import save_contribution

            contrib_path = save_contribution(
                pipeline,
                model_name=model_name,
                notes=getattr(args, "contribute_notes", ""),
            )
        except Exception as e:
            console.print(f"[yellow]Could not save contribution: {e}[/yellow]")

    console.print()
    contrib_line = ""
    if contrib_path:
        contrib_line = f"\n Contribution: [cyan]{contrib_path}[/]"
    console.print(
        Panel(
            f"[bold green]Abliteration complete![/]\n\n"
            f" Model saved to: [cyan]{result_path}[/]\n"
            f" Metadata: [cyan]{result_path}/abliteration_metadata.json[/]"
            f"{contrib_line}\n\n"
            f" [dim]Load with:[/] AutoModelForCausalLM.from_pretrained('{result_path}')",
            border_style="green",
            title="[bold green]✓ REBIRTH COMPLETE[/]",
        )
    )
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    main()