mirror of https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 06:35:59 +02:00
508 lines · 18 KiB · Python
"""CLI entry point for Obliteratus — Master Ablation Suite."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
from rich.console import Console
|
|
|
|
console = Console()
|
|
|
|
_BANNER = r"""
|
|
[bold red]
|
|
░▒█▀▀▀█ ░▒█▀▀▄ ░▒█░░░ ▀█▀ ▀▀█▀▀ ░▒█▀▀▀ ░▒█▀▀█ ▒█▀▀█ ▀▀█▀▀ ░▒█░░▒█ ░▒█▀▀▀█
|
|
░▒█░░▒█ ░▒█▀▀▄ ░▒█░░░ ░█░ ░░█░░ ░▒█▀▀▀ ░▒█▄▄▀ ▒█▄▄█ ░░█░░ ░▒█░░▒█ ░░▀▀▀▄▄
|
|
░▒█▄▄▄█ ░▒█▄▄▀ ░▒█▄▄█ ▄█▄ ░░▀░░ ░▒█▄▄▄ ░▒█░▒█ ▒█░▒█ ░░▀░░ ░░▒█▄▄█ ░▒█▄▄▄█
|
|
[/bold red]
|
|
[dim] ════════════════════════════════════════════════════════════════════[/dim]
|
|
[bold white] MASTER ABLATION SUITE[/bold white] [dim]//[/dim] [bold red]Break the chains. Free the mind.[/bold red]
|
|
[dim] ════════════════════════════════════════════════════════════════════[/dim]
|
|
"""
|
|
|
|
|
|
def main(argv: list[str] | None = None):
|
|
console.print(_BANNER)
|
|
parser = argparse.ArgumentParser(
|
|
prog="obliteratus",
|
|
description="Master Ablation Suite for HuggingFace transformers",
|
|
)
|
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
|
|
# --- run ---
|
|
run_parser = subparsers.add_parser("run", help="Run an ablation from a YAML config")
|
|
run_parser.add_argument("config", type=str, help="Path to YAML config file")
|
|
run_parser.add_argument("--output-dir", type=str, default=None, help="Override output dir")
|
|
run_parser.add_argument(
|
|
"--preset",
|
|
type=str,
|
|
default=None,
|
|
help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
|
|
)
|
|
|
|
# --- info ---
|
|
info_parser = subparsers.add_parser("info", help="Print model architecture info")
|
|
info_parser.add_argument("model", type=str, help="HuggingFace model name/path")
|
|
info_parser.add_argument("--task", type=str, default="causal_lm", choices=["causal_lm", "classification"])
|
|
info_parser.add_argument("--device", type=str, default="cpu")
|
|
info_parser.add_argument("--dtype", type=str, default="float32")
|
|
|
|
# --- interactive ---
|
|
subparsers.add_parser(
|
|
"interactive",
|
|
help="Guided setup — pick hardware, model, and preset interactively",
|
|
)
|
|
|
|
# --- models ---
|
|
models_parser = subparsers.add_parser("models", help="Browse curated models by compute tier")
|
|
models_parser.add_argument(
|
|
"--tier",
|
|
type=str,
|
|
default=None,
|
|
choices=["tiny", "small", "medium", "large", "frontier"],
|
|
help="Filter by compute tier",
|
|
)
|
|
|
|
# --- presets ---
|
|
subparsers.add_parser("presets", help="Browse ablation presets (quick, full, jailbreak, etc.)")
|
|
|
|
# --- strategies ---
|
|
subparsers.add_parser("strategies", help="List available ablation strategies")
|
|
|
|
# --- ui ---
|
|
ui_parser = subparsers.add_parser(
|
|
"ui",
|
|
help="Launch the Gradio web UI locally (same UI as the HuggingFace Space)",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--port", type=int, default=7860, help="Server port (default: 7860)",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--host", type=str, default="0.0.0.0", help="Server host (default: 0.0.0.0)",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--share", action="store_true", help="Create a public Gradio share link",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--no-browser", action="store_true", help="Don't auto-open browser on launch",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--auth", type=str, default=None,
|
|
help="Basic auth as user:pass",
|
|
)
|
|
ui_parser.add_argument(
|
|
"--quiet", action="store_true", help="Suppress the startup banner",
|
|
)
|
|
|
|
# --- obliterate (primary) + abliterate (backward-compat alias) ---
|
|
def _add_obliterate_args(p):
|
|
p.add_argument("model", type=str, help="HuggingFace model name/path")
|
|
p.add_argument("--output-dir", type=str, default=None, help="Where to save the obliterated model")
|
|
p.add_argument("--device", type=str, default="auto")
|
|
p.add_argument("--dtype", type=str, default="float16")
|
|
p.add_argument(
|
|
"--method", type=str, default="advanced",
|
|
choices=[
|
|
"basic", "advanced", "aggressive", "spectral_cascade",
|
|
"informed", "surgical", "optimized", "inverted", "nuclear",
|
|
],
|
|
help="Liberation method (default: advanced)",
|
|
)
|
|
p.add_argument("--n-directions", type=int, default=None, help="Override: number of SVD directions to extract")
|
|
p.add_argument("--regularization", type=float, default=None, help="Override: fraction to preserve (0.0-1.0)")
|
|
p.add_argument("--refinement-passes", type=int, default=None, help="Override: number of iterative passes")
|
|
p.add_argument(
|
|
"--quantization", type=str, default=None, choices=["4bit", "8bit"],
|
|
help="Load model with quantization (4bit or 8bit). Requires bitsandbytes.",
|
|
)
|
|
p.add_argument(
|
|
"--large-model", action="store_true", default=False,
|
|
help="Enable conservative defaults for 120B+ models (fewer directions, 1 pass, lower SAE expansion).",
|
|
)
|
|
p.add_argument(
|
|
"--verify-sample-size", type=int, default=None,
|
|
help="Number of harmful prompts to test for refusal rate (default: 30). "
|
|
"Increase for tighter confidence intervals (e.g. 100 for ~1%% resolution).",
|
|
)
|
|
p.add_argument(
|
|
"--contribute", action="store_true", default=False,
|
|
help="Save a community contribution record after the run completes.",
|
|
)
|
|
p.add_argument(
|
|
"--contribute-notes", type=str, default="",
|
|
help="Optional notes to include with the community contribution.",
|
|
)
|
|
|
|
abl_parser = subparsers.add_parser(
|
|
"obliterate",
|
|
help="One-click: remove refusal directions from a model (SOTA multi-technique)",
|
|
)
|
|
_add_obliterate_args(abl_parser)
|
|
# Backward-compat alias (hidden from help)
|
|
abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS)
|
|
_add_obliterate_args(abl_alias)
|
|
|
|
# --- report ---
|
|
report_parser = subparsers.add_parser("report", help="Regenerate report from saved results")
|
|
report_parser.add_argument("results_json", type=str, help="Path to results.json")
|
|
report_parser.add_argument("--output-dir", type=str, default=None)
|
|
|
|
# --- aggregate ---
|
|
aggregate_parser = subparsers.add_parser("aggregate", help="Aggregate community contribution results")
|
|
aggregate_parser.add_argument(
|
|
"--dir", type=str, default="community_results",
|
|
help="Directory containing contribution JSON files",
|
|
)
|
|
|
|
args = parser.parse_args(argv)
|
|
|
|
if args.command == "run":
|
|
_cmd_run(args)
|
|
elif args.command == "interactive":
|
|
_cmd_interactive()
|
|
elif args.command == "models":
|
|
_cmd_models(args)
|
|
elif args.command == "presets":
|
|
_cmd_presets()
|
|
elif args.command == "info":
|
|
_cmd_info(args)
|
|
elif args.command == "strategies":
|
|
_cmd_strategies()
|
|
elif args.command == "report":
|
|
_cmd_report(args)
|
|
elif args.command == "aggregate":
|
|
_cmd_aggregate(args)
|
|
elif args.command == "ui":
|
|
_cmd_ui(args)
|
|
elif args.command in ("obliterate", "abliterate"):
|
|
_cmd_abliterate(args)
|
|
|
|
|
|
def _cmd_ui(args):
|
|
from obliteratus.local_ui import launch_local_ui
|
|
|
|
auth = tuple(args.auth.split(":", 1)) if args.auth else None
|
|
launch_local_ui(
|
|
host=args.host,
|
|
port=args.port,
|
|
share=args.share,
|
|
open_browser=not args.no_browser,
|
|
auth=auth,
|
|
quiet=args.quiet,
|
|
)
|
|
|
|
|
|
def _cmd_interactive():
|
|
from obliteratus.interactive import run_interactive
|
|
run_interactive()
|
|
|
|
|
|
def _cmd_models(args):
|
|
from rich.table import Table
|
|
|
|
from obliteratus.presets import get_presets_by_tier, list_all_presets
|
|
|
|
presets = get_presets_by_tier(args.tier) if args.tier else list_all_presets()
|
|
|
|
table = Table(title="Model Library — Curated Targets")
|
|
table.add_column("Model", style="green")
|
|
table.add_column("HuggingFace ID", style="cyan")
|
|
table.add_column("Params", justify="right")
|
|
table.add_column("Tier", style="yellow")
|
|
table.add_column("Dtype")
|
|
table.add_column("Quant")
|
|
table.add_column("Description")
|
|
|
|
for p in presets:
|
|
table.add_row(
|
|
p.name,
|
|
p.hf_id,
|
|
p.params,
|
|
p.tier.upper(),
|
|
p.recommended_dtype,
|
|
p.recommended_quantization or "—",
|
|
p.description,
|
|
)
|
|
|
|
console.print(table)
|
|
console.print(
|
|
"\n[dim]Tiers: TINY = CPU/laptop | SMALL = 4-8GB | "
|
|
"MEDIUM = 8-16GB | LARGE = 24GB+ | FRONTIER = multi-GPU/cloud[/dim]"
|
|
)
|
|
|
|
|
|
def _cmd_presets():
|
|
from rich.table import Table
|
|
|
|
from obliteratus.study_presets import list_study_presets
|
|
|
|
presets = list_study_presets()
|
|
|
|
table = Table(title="Ablation Presets")
|
|
table.add_column("Key", style="cyan", min_width=12)
|
|
table.add_column("Name", style="green")
|
|
table.add_column("Strategies", style="yellow")
|
|
table.add_column("Samples", justify="right")
|
|
table.add_column("Description", max_width=55)
|
|
|
|
for p in presets:
|
|
strats = ", ".join(s["name"] for s in p.strategies)
|
|
table.add_row(p.key, p.name, strats, str(p.max_samples), p.description)
|
|
|
|
console.print(table)
|
|
console.print(
|
|
"\n[dim]Usage: obliteratus run config.yaml --preset quick\n"
|
|
" or: set preset: quick in your YAML file[/dim]"
|
|
)
|
|
|
|
|
|
def _cmd_run(args):
|
|
from obliteratus.config import StudyConfig
|
|
from obliteratus.runner import run_study
|
|
|
|
config = StudyConfig.from_yaml(args.config)
|
|
# If --preset flag given, inject it so from_dict picks it up
|
|
if args.preset:
|
|
import yaml
|
|
|
|
raw = yaml.safe_load(Path(args.config).read_text())
|
|
raw["preset"] = args.preset
|
|
config = StudyConfig.from_dict(raw)
|
|
if args.output_dir:
|
|
config.output_dir = args.output_dir
|
|
run_study(config)
|
|
|
|
|
|
def _cmd_info(args):
|
|
from obliteratus.models.loader import load_model
|
|
|
|
console.print(f"[bold cyan]Loading model:[/bold cyan] {args.model}")
|
|
handle = load_model(
|
|
model_name=args.model,
|
|
task=args.task,
|
|
device=args.device,
|
|
dtype=args.dtype,
|
|
)
|
|
summary = handle.summary()
|
|
for key, val in summary.items():
|
|
if isinstance(val, int) and val > 1000:
|
|
console.print(f" {key}: {val:,}")
|
|
else:
|
|
console.print(f" {key}: {val}")
|
|
|
|
|
|
def _cmd_strategies():
|
|
from obliteratus.strategies import STRATEGY_REGISTRY
|
|
|
|
console.print("[bold]Available ablation strategies:[/bold]\n")
|
|
for name, cls in sorted(STRATEGY_REGISTRY.items()):
|
|
doc = (cls.__doc__ or "").strip().split("\n")[0]
|
|
console.print(f" [cyan]{name}[/cyan] — {doc}")
|
|
|
|
|
|
def _cmd_report(args):
|
|
from obliteratus.reporting.report import AblationReport, AblationResult
|
|
|
|
path = Path(args.results_json)
|
|
data = json.loads(path.read_text())
|
|
|
|
report = AblationReport(model_name=data["model_name"])
|
|
report.add_baseline(data["baseline_metrics"])
|
|
for r in data["results"]:
|
|
report.add_result(
|
|
AblationResult(
|
|
strategy=r["strategy"],
|
|
component=r["component"],
|
|
description=r["description"],
|
|
metrics=r["metrics"],
|
|
metadata=r.get("metadata"),
|
|
)
|
|
)
|
|
|
|
report.print_summary()
|
|
|
|
output_dir = Path(args.output_dir) if args.output_dir else path.parent
|
|
metric_name = list(data["baseline_metrics"].keys())[0]
|
|
try:
|
|
report.plot_impact(metric=metric_name, output_path=output_dir / "impact.png")
|
|
report.plot_heatmap(output_path=output_dir / "heatmap.png")
|
|
console.print(f"\nPlots saved to {output_dir}/")
|
|
except Exception as e:
|
|
console.print(f"[yellow]Could not generate plots: {e}[/yellow]")
|
|
|
|
|
|
def _cmd_aggregate(args):
|
|
from obliteratus.community import aggregate_results, load_contributions
|
|
|
|
contrib_dir = args.dir
|
|
records = load_contributions(contrib_dir)
|
|
if not records:
|
|
console.print(f"[yellow]No contributions found in {contrib_dir}[/yellow]")
|
|
return
|
|
|
|
aggregated = aggregate_results(records)
|
|
|
|
from rich.table import Table
|
|
|
|
table = Table(title="Aggregated Community Results")
|
|
table.add_column("Model", style="green")
|
|
table.add_column("Method", style="cyan")
|
|
table.add_column("Runs", justify="right")
|
|
table.add_column("Mean Refusal", justify="right")
|
|
table.add_column("Mean Perplexity", justify="right")
|
|
|
|
for model_name, methods in sorted(aggregated.items()):
|
|
for method_name, stats in sorted(methods.items()):
|
|
refusal = stats.get("refusal_rate", {}).get("mean", "N/A")
|
|
ppl = stats.get("perplexity", {}).get("mean", "N/A")
|
|
if isinstance(refusal, float):
|
|
refusal = f"{refusal:.4f}"
|
|
if isinstance(ppl, float):
|
|
ppl = f"{ppl:.2f}"
|
|
table.add_row(
|
|
model_name.split("/")[-1] if "/" in model_name else model_name,
|
|
method_name,
|
|
str(stats["n_runs"]),
|
|
str(refusal),
|
|
str(ppl),
|
|
)
|
|
|
|
console.print(table)
|
|
|
|
|
|
def _cmd_abliterate(args):
|
|
from rich.live import Live
|
|
from rich.panel import Panel
|
|
from rich.table import Table
|
|
from rich.text import Text
|
|
|
|
from obliteratus.abliterate import METHODS, STAGES, AbliterationPipeline
|
|
|
|
model_name = args.model
|
|
output_dir = args.output_dir or f"abliterated/{model_name.replace('/', '_')}"
|
|
method = args.method
|
|
method_label = METHODS.get(method, {}).get("label", method)
|
|
|
|
# Stage state tracking
|
|
stage_status = {s.key: "waiting" for s in STAGES}
|
|
stage_msgs = {s.key: "" for s in STAGES}
|
|
log_lines: list[str] = []
|
|
|
|
def make_display():
|
|
table = Table(show_header=False, expand=True, border_style="green")
|
|
table.add_column("", width=6)
|
|
table.add_column("Stage", min_width=10)
|
|
table.add_column("Status", min_width=50)
|
|
for i, s in enumerate(STAGES):
|
|
st = stage_status[s.key]
|
|
if st == "done":
|
|
icon = "[bold green]✓[/]"
|
|
bar = "[green]" + "█" * 20 + "[/]"
|
|
elif st == "running":
|
|
icon = "[bold yellow]⚡[/]"
|
|
bar = "[yellow]" + "▓" * 10 + "░" * 10 + "[/]"
|
|
else:
|
|
icon = "[dim]○[/]"
|
|
bar = "[dim]" + "░" * 20 + "[/]"
|
|
msg = stage_msgs.get(s.key, "")
|
|
table.add_row(
|
|
f"[cyan][{i + 1}/6][/]",
|
|
f"{icon} [bold]{s.name}[/]",
|
|
f"{bar} {msg}",
|
|
)
|
|
|
|
header = Text.from_markup(
|
|
f"[bold green]OBLITERATUS — ABLITERATION PIPELINE[/]\n"
|
|
f"[dim]Target:[/] [cyan]{model_name}[/] → [cyan]{output_dir}[/]\n"
|
|
f"[dim]Method:[/] [magenta]{method_label}[/]"
|
|
)
|
|
|
|
# Last 12 log lines
|
|
recent = log_lines[-12:] if log_lines else ["Initializing..."]
|
|
log_text = "\n".join(f"[dim]>[/] {line}" for line in recent)
|
|
|
|
return Panel(
|
|
f"{header}\n\n{table}\n\n[dim]─── LOG ───[/]\n{log_text}",
|
|
border_style="green",
|
|
title="[bold green]⚗ ABLITERATE ⚗[/]",
|
|
)
|
|
|
|
def on_stage(result):
|
|
stage_status[result.stage] = result.status
|
|
stage_msgs[result.stage] = result.message
|
|
if live:
|
|
live.update(make_display())
|
|
|
|
def on_log(msg):
|
|
log_lines.append(msg)
|
|
if live:
|
|
live.update(make_display())
|
|
|
|
live = None
|
|
pipeline = AbliterationPipeline(
|
|
model_name=model_name,
|
|
output_dir=output_dir,
|
|
device=args.device,
|
|
dtype=args.dtype,
|
|
method=method,
|
|
n_directions=args.n_directions,
|
|
regularization=args.regularization,
|
|
refinement_passes=args.refinement_passes,
|
|
quantization=args.quantization,
|
|
large_model_mode=getattr(args, "large_model", False),
|
|
verify_sample_size=getattr(args, "verify_sample_size", None),
|
|
on_stage=on_stage,
|
|
on_log=on_log,
|
|
)
|
|
|
|
with Live(make_display(), console=console, refresh_per_second=4) as live_ctx:
|
|
live = live_ctx
|
|
try:
|
|
result_path = pipeline.run()
|
|
live.update(make_display())
|
|
except Exception as e:
|
|
log_lines.append(f"[red]ERROR: {e}[/]")
|
|
live.update(make_display())
|
|
raise
|
|
|
|
# ── Telemetry: send pipeline report to community leaderboard ──
|
|
try:
|
|
from obliteratus.telemetry import maybe_send_pipeline_report
|
|
maybe_send_pipeline_report(pipeline)
|
|
except Exception:
|
|
pass # Telemetry is best-effort
|
|
|
|
# ── Community contribution (--contribute flag) ──
|
|
contrib_path = None
|
|
if getattr(args, "contribute", False):
|
|
try:
|
|
from obliteratus.community import save_contribution
|
|
contrib_path = save_contribution(
|
|
pipeline,
|
|
model_name=model_name,
|
|
notes=getattr(args, "contribute_notes", ""),
|
|
)
|
|
except Exception as e:
|
|
console.print(f"[yellow]Could not save contribution: {e}[/yellow]")
|
|
|
|
console.print()
|
|
contrib_line = ""
|
|
if contrib_path:
|
|
contrib_line = f"\n Contribution: [cyan]{contrib_path}[/]"
|
|
console.print(
|
|
Panel(
|
|
f"[bold green]Abliteration complete![/]\n\n"
|
|
f" Model saved to: [cyan]{result_path}[/]\n"
|
|
f" Metadata: [cyan]{result_path}/abliteration_metadata.json[/]"
|
|
f"{contrib_line}\n\n"
|
|
f" [dim]Load with:[/] AutoModelForCausalLM.from_pretrained('{result_path}')",
|
|
border_style="green",
|
|
title="[bold green]✓ REBIRTH COMPLETE[/]",
|
|
)
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|